first commit
This commit is contained in:
879
data_representation/encoding_utils.py
Normal file
879
data_representation/encoding_utils.py
Normal file
@ -0,0 +1,879 @@
|
||||
from typing import Any
|
||||
from fractions import Fraction
|
||||
from collections import defaultdict
|
||||
|
||||
from miditoolkit import TimeSignature
|
||||
|
||||
from constants import *
|
||||
|
||||
'''
|
||||
This script contains specific encoding functions for different encoding schemes.
|
||||
'''
|
||||
|
||||
def frange(start, stop, step):
|
||||
while start < stop:
|
||||
yield start
|
||||
start += step
|
||||
|
||||
################################# for REMI style encoding #################################
|
||||
|
||||
class Corpus2event_remi():
|
||||
def __init__(self, num_features:int):
|
||||
self.num_features = num_features
|
||||
|
||||
def _create_event(self, name, value):
|
||||
event = dict()
|
||||
event['name'] = name
|
||||
event['value'] = value
|
||||
return event
|
||||
def _break_down_numerator(self, numerator, possible_time_signatures):
|
||||
"""Break down a numerator into smaller time signatures.
|
||||
|
||||
Args:
|
||||
numerator: Target numerator to decompose (must be > 0).
|
||||
possible_time_signatures: List of (numerator, denominator) tuples,
|
||||
sorted in descending order (e.g., [(4,4), (3,4)]).
|
||||
|
||||
Returns:
|
||||
List of decomposed time signatures (e.g., [(4,4), (3,4)]).
|
||||
|
||||
Raises:
|
||||
ValueError: If decomposition is impossible.
|
||||
"""
|
||||
if numerator <= 0:
|
||||
raise ValueError("Numerator must be positive.")
|
||||
if not possible_time_signatures:
|
||||
raise ValueError("No possible time signatures provided.")
|
||||
|
||||
result = []
|
||||
original_numerator = numerator # For error message
|
||||
|
||||
# Sort signatures in descending order to prioritize larger chunks
|
||||
possible_time_signatures = sorted(possible_time_signatures, key=lambda x: -x[0])
|
||||
|
||||
while numerator > 0:
|
||||
subtracted = False # Track if any subtraction occurred in this iteration
|
||||
|
||||
for sig in possible_time_signatures:
|
||||
sig_numerator, _ = sig
|
||||
if sig_numerator <= 0:
|
||||
continue # Skip invalid signatures
|
||||
|
||||
while numerator >= sig_numerator:
|
||||
result.append(sig)
|
||||
numerator -= sig_numerator
|
||||
subtracted = True
|
||||
|
||||
# If no progress was made, decomposition failed
|
||||
if not subtracted:
|
||||
raise ValueError(
|
||||
f"Cannot decompose numerator {original_numerator} "
|
||||
f"with given time signatures {possible_time_signatures}. "
|
||||
f"Remaining: {numerator}"
|
||||
)
|
||||
|
||||
return result
|
||||
def _normalize_time_signature(self, time_signature, ticks_per_beat, next_change_point):
|
||||
"""
|
||||
Normalize irregular time signatures to standard ones by breaking them down
|
||||
into common time signatures, and adjusting their durations to fit the given
|
||||
musical structure.
|
||||
|
||||
Parameters:
|
||||
- time_signature: TimeSignature object with numerator, denominator, and start time.
|
||||
- ticks_per_beat: Number of ticks per beat, representing the resolution of the timing.
|
||||
- next_change_point: Tick position where the next time signature change occurs.
|
||||
|
||||
Returns:
|
||||
- A list of TimeSignature objects, normalized to fit within regular time signatures.
|
||||
|
||||
Procedure:
|
||||
1. If the time signature is already a standard one (in REGULAR_NUM_DENOM), return it.
|
||||
2. For non-standard signatures, break them down into simpler, well-known signatures.
|
||||
- For unusual denominations (e.g., 16th, 32nd, or 64th notes), normalize to 4/4.
|
||||
- For 6/4 signatures, break it into two 3/4 measures.
|
||||
3. If the time signature has a non-standard numerator and denominator, break it down
|
||||
into the largest possible numerators that still fit within the denominator.
|
||||
This ensures that the final measure fits within the regular time signature format.
|
||||
4. Calculate the resolution (duration in ticks) for each bar and ensure the bars
|
||||
fit within the time until the next change point.
|
||||
- Adjust the number of bars if they exceed the available space.
|
||||
- If the total length is too short, repeat the first (largest) bar to fill the gap.
|
||||
5. Convert the breakdown into TimeSignature objects and return the normalized result.
|
||||
"""
|
||||
|
||||
# Check if the time signature is a regular one, return it if so
|
||||
if (time_signature.numerator, time_signature.denominator) in REGULAR_NUM_DENOM:
|
||||
return [time_signature]
|
||||
|
||||
# Extract time signature components
|
||||
numerator, denominator, bar_start_tick = time_signature.numerator, time_signature.denominator, time_signature.time
|
||||
|
||||
# Normalize time signatures with 16th, 32nd, or 64th note denominators to 4/4
|
||||
if denominator in [16, 32, 64]:
|
||||
return [TimeSignature(4, 4, time_signature.time)]
|
||||
|
||||
# Special case for 6/4, break it into two 3/4 bars
|
||||
elif denominator == 6 and numerator == 4:
|
||||
return [TimeSignature(3, 4, time_signature.time), TimeSignature(3, 4, time_signature.time)]
|
||||
|
||||
# Determine possible regular signatures for the given denominator
|
||||
possible_time_signatures = [sig for sig in CORE_NUM_DENOM if sig[1] == denominator]
|
||||
|
||||
# Sort by numerator in descending order to prioritize larger numerators
|
||||
possible_time_signatures.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
result = []
|
||||
|
||||
# Break down the numerator into smaller regular numerators
|
||||
max_iterations = 100 # Prevent infinite loops
|
||||
original_numerator = numerator # Store original for error message
|
||||
|
||||
# Break down the numerator into smaller regular numerators
|
||||
iteration_count = 0
|
||||
while numerator > 0:
|
||||
iteration_count += 1
|
||||
if iteration_count > max_iterations:
|
||||
raise ValueError(
|
||||
f"Failed to normalize time signature {original_numerator}/{denominator}. "
|
||||
f"Could not break down numerator {original_numerator} with available signatures: "
|
||||
f"{possible_time_signatures}"
|
||||
)
|
||||
|
||||
for sig in possible_time_signatures:
|
||||
# Subtract numerators and add to the result
|
||||
while numerator >= sig[0]:
|
||||
result.append(sig)
|
||||
numerator -= sig[0]
|
||||
|
||||
|
||||
|
||||
# Calculate the resolution (length in ticks) of each bar
|
||||
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
|
||||
|
||||
# Adjust bars to fit within the remaining ticks before the next change point
|
||||
total_length = 0
|
||||
for idx, bar_resol in enumerate(bar_resol_list):
|
||||
total_length += bar_resol
|
||||
if total_length > next_change_point - bar_start_tick:
|
||||
result = result[:idx+1]
|
||||
break
|
||||
|
||||
# If the total length is too short, repeat the first (largest) bar until the gap is filled
|
||||
while total_length < next_change_point - bar_start_tick:
|
||||
result.append(result[0])
|
||||
total_length += int(ticks_per_beat * result[0][0] * (4 / result[0][1]))
|
||||
|
||||
# Recalculate bar resolutions for the final result
|
||||
bar_resol_list = [int(ticks_per_beat * numerator * (4 / denominator)) for numerator, denominator in result]
|
||||
|
||||
# Insert a starting resolution of 0 and calculate absolute tick positions for each TimeSignature
|
||||
bar_resol_list.insert(0, 0)
|
||||
total_length = bar_start_tick
|
||||
normalized_result = []
|
||||
for sig, length in zip(result, bar_resol_list):
|
||||
total_length += length
|
||||
normalized_result.append(TimeSignature(sig[0], sig[1], total_length))
|
||||
|
||||
return normalized_result
|
||||
|
||||
def _process_time_signature(self, time_signature_changes, ticks_per_beat, first_note_tick, global_end):
|
||||
"""
|
||||
Process and normalize time signature changes for a given musical piece.
|
||||
|
||||
Parameters:
|
||||
- time_signature_changes: A list of TimeSignature objects representing time signature changes in the music.
|
||||
- ticks_per_beat: The resolution of timing in ticks per beat.
|
||||
- first_note_tick: The tick position of the first note in the piece.
|
||||
- global_end: The tick position where the piece ends.
|
||||
|
||||
Returns:
|
||||
- A list of processed and normalized time signature changes. If no valid time signature
|
||||
changes are found, returns None.
|
||||
|
||||
Procedure:
|
||||
1. Check the validity of the time signature changes:
|
||||
- Ensure there is at least one time signature change.
|
||||
- Ensure the first time signature change occurs at the beginning (before the first note).
|
||||
2. Remove duplicate consecutive time signatures:
|
||||
- Only add time signatures that differ from the previous one (de-duplication).
|
||||
3. Normalize the time signatures:
|
||||
- For each time signature, determine its duration by calculating the time until the
|
||||
next change point or the end of the piece.
|
||||
- Use the _normalize_time_signature method to break down non-standard signatures into
|
||||
simpler, well-known signatures that fit within the musical structure.
|
||||
4. Return the processed and normalized time signature changes.
|
||||
|
||||
"""
|
||||
|
||||
# Check if there are any time signature changes
|
||||
if len(time_signature_changes) == 0:
|
||||
print("No time signature change in this tune, default to 4/4 time signature")
|
||||
# default to 4/4 time signature if none are found
|
||||
return [TimeSignature(4, 4, 0)]
|
||||
|
||||
# Ensure the first time signature change is at the start of the piece (before the first note)
|
||||
if time_signature_changes[0].time != 0 and time_signature_changes[0].time > first_note_tick:
|
||||
print("The first time signature change is not at the beginning of the tune")
|
||||
return None
|
||||
|
||||
# Remove consecutive duplicate time signatures (de-duplication)
|
||||
processed_time_signature_changes = []
|
||||
for idx, time_sig in enumerate(time_signature_changes):
|
||||
if idx == 0:
|
||||
processed_time_signature_changes.append(time_sig)
|
||||
else:
|
||||
prev_time_sig = time_signature_changes[idx-1]
|
||||
# Only add time signature if it's different from the previous one
|
||||
if not (prev_time_sig.numerator == time_sig.numerator and prev_time_sig.denominator == time_sig.denominator):
|
||||
processed_time_signature_changes.append(time_sig)
|
||||
|
||||
# Normalize the time signatures to standard formats
|
||||
normalized_time_signature_changes = []
|
||||
for idx, time_signature in enumerate(processed_time_signature_changes):
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
# If it's the last time signature change, set the next change point as the end of the piece
|
||||
next_change_point = global_end
|
||||
else:
|
||||
# Otherwise, set the next change point as the next time signature's start time
|
||||
next_change_point = time_signature_changes[idx+1].time
|
||||
|
||||
# Normalize the current time signature and extend the result
|
||||
normalized_time_signature_changes.extend(self._normalize_time_signature(time_signature, ticks_per_beat, next_change_point))
|
||||
|
||||
# Return the list of processed and normalized time signatures
|
||||
time_signature_changes = normalized_time_signature_changes
|
||||
return time_signature_changes
|
||||
|
||||
def _half_step_interval_gap_check_across_instruments(self, instrument_note_dict):
|
||||
'''
|
||||
This function checks for half-step interval gaps between notes across different instruments.
|
||||
It will avoid half-step intervals by keeping one note from any pair of notes that are a half-step apart,
|
||||
regardless of which instrument they belong to.
|
||||
'''
|
||||
# order instrument_note_dict by pitch in descending order
|
||||
instrument_note_dict = dict(sorted(instrument_note_dict.items()))
|
||||
|
||||
# Create a dictionary to store all pitches across instruments
|
||||
all_pitches = {}
|
||||
|
||||
# Collect all pitches from each instrument and sort them in descending order
|
||||
for instrument, notes in instrument_note_dict.items():
|
||||
for pitch, durations in notes.items():
|
||||
all_pitches[pitch] = all_pitches.get(pitch, []) + [(instrument, durations)]
|
||||
|
||||
# Sort the pitches in descending order
|
||||
sorted_pitches = sorted(all_pitches.keys(), reverse=True)
|
||||
|
||||
# Create a new list to store the final pitches after comparison
|
||||
final_pitch_list = []
|
||||
|
||||
# Use an index pointer to control the sliding window
|
||||
idx = 0
|
||||
while idx < len(sorted_pitches) - 1:
|
||||
current_pitch = sorted_pitches[idx]
|
||||
next_pitch = sorted_pitches[idx + 1]
|
||||
|
||||
if current_pitch - next_pitch == 1: # Check for a half-step interval gap
|
||||
current_max_duration = max(duration for _, durations in all_pitches[current_pitch] for duration, _ in durations)
|
||||
next_max_duration = max(duration for _, durations in all_pitches[next_pitch] for duration, _ in durations)
|
||||
|
||||
if current_max_duration < next_max_duration:
|
||||
# Keep the higher pitch (next_pitch) and skip the current_pitch
|
||||
final_pitch_list.append(next_pitch)
|
||||
else:
|
||||
# Keep the lower pitch (current_pitch) and skip the next_pitch
|
||||
final_pitch_list.append(current_pitch)
|
||||
|
||||
# Skip the next pitch because we already handled it
|
||||
idx += 2
|
||||
else:
|
||||
# No half-step gap, keep the current pitch and move to the next one
|
||||
final_pitch_list.append(current_pitch)
|
||||
idx += 1
|
||||
|
||||
# Ensure the last pitch is added if it's not part of a half-step interval
|
||||
if idx == len(sorted_pitches) - 1:
|
||||
final_pitch_list.append(sorted_pitches[-1])
|
||||
|
||||
# Filter out notes not in the final pitch list and update the instrument_note_dict
|
||||
for instrument in instrument_note_dict.keys():
|
||||
instrument_note_dict[instrument] = {
|
||||
pitch: instrument_note_dict[instrument][pitch]
|
||||
for pitch in sorted(instrument_note_dict[instrument].keys(), reverse=True) if pitch in final_pitch_list
|
||||
}
|
||||
|
||||
return instrument_note_dict
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution):
|
||||
'''
|
||||
Process a song's data to generate a sequence of musical events, including bars, chords, tempo,
|
||||
and notes, similar to the approach used in the CP paper (corpus2event_remi_v2).
|
||||
|
||||
Parameters:
|
||||
- song_data: A dictionary containing metadata, notes, chords, and tempos of the song.
|
||||
- in_beat_resolution: The resolution of timing in beats (how many divisions per beat).
|
||||
|
||||
Returns:
|
||||
- A sequence of musical events including start (SOS), bars, chords, tempo, instruments, notes,
|
||||
and an end (EOS) event. If the time signature is invalid, returns None.
|
||||
|
||||
Procedure:
|
||||
1. **Global Setup**:
|
||||
- Extract global metadata like first and last note ticks, time signature changes, and ticks
|
||||
per beat.
|
||||
- Compute `in_beat_tick_resol`, the ratio of ticks per beat to the input beat resolution,
|
||||
to assist in dividing bars later.
|
||||
- Get a sorted list of instruments in the song.
|
||||
|
||||
2. **Time Signature Processing**:
|
||||
- Call `_process_time_signature` to clean up and normalize the time signatures in the song.
|
||||
- If the time signatures are invalid (e.g., no time signature changes or missing at the
|
||||
start), the function exits early with None.
|
||||
|
||||
3. **Sequence Generation**:
|
||||
- Initialize the sequence with a start token (SOS) and prepare variables for tracking
|
||||
previous chord, tempo, and instrument states.
|
||||
- Loop through each time signature change, dividing the song into measures based on the
|
||||
current time signature's numerator and denominator.
|
||||
- For each measure, append "Bar" tokens to mark measure boundaries, while ensuring that no
|
||||
more than four consecutive empty bars are added.
|
||||
- For each step within a measure, process the following:
|
||||
- **Chords**: If there is a chord change, add a corresponding chord event.
|
||||
- **Tempo**: If the tempo changes, add a tempo event.
|
||||
- **Notes**: Iterate over each instrument, adding notes and checking for half-step
|
||||
intervals, deduplicating notes, and choosing the longest duration for each pitch.
|
||||
- Append a "Beat" event for each step with musical events.
|
||||
|
||||
4. **End Sequence**:
|
||||
- Conclude the sequence by appending a final "Bar" token followed by an end token (EOS).
|
||||
'''
|
||||
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # Starting tick of the first note
|
||||
global_end = song_data['metadata']['last_note'] # Ending tick of the last note
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat resolution
|
||||
# Resolution for dividing beats within measures, expressed as a fraction
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Example: 1024/12 -> (256, 3)
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # Get a sorted list of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Normalize and process the time signatures in the song
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if time signature is invalid
|
||||
|
||||
# --- create sequence --- #
|
||||
prev_instr_idx = None # Track the previously processed instrument
|
||||
final_sequence = []
|
||||
final_sequence.append(self._create_event('SOS', None)) # Add Start of Sequence (SOS) token
|
||||
prev_chord = None # Track the previous chord
|
||||
prev_tempo = None # Track the previous tempo
|
||||
chord_value = None
|
||||
tempo_value = None
|
||||
|
||||
# Process each time signature change
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to indicate a time signature change
|
||||
# Calculate bar resolution based on the current time signature
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format time signature name
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate bar resolution in ticks
|
||||
bar_start_tick = time_signature_changes[idx].time # Start tick of the current bar
|
||||
# Determine the next time signature change point or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx+1].time
|
||||
|
||||
# Process each measure within the current time signature
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
empty_bar_token = self._create_event('Bar', None) # Token for empty bars
|
||||
|
||||
# Ensure no more than 4 consecutive empty bars are added
|
||||
if len(final_sequence) >= 4:
|
||||
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and
|
||||
final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name)) # Mark new bar with time signature
|
||||
else:
|
||||
final_sequence.append(self._create_event('Bar', None))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self._create_event('Bar', time_sig_name))
|
||||
else:
|
||||
final_sequence.append(self._create_event('Bar', None))
|
||||
|
||||
time_sig_change_flag = False # Reset time signature change flag
|
||||
|
||||
# Process events within each beat
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
events_list = []
|
||||
# Retrieve chords and tempos at the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# Process chord and tempo if the number of features allows for it
|
||||
if self.num_features in {8, 7}:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_') # Extract chord info
|
||||
chord_value = root + '_' + quality
|
||||
if t_tempos is not None:
|
||||
tempo_value = t_tempos[-1].tempo # Extract tempo value
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# Process notes for each instrument at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If there are notes at this beat step, process them.
|
||||
if t_notes is not None:
|
||||
# Track notes to avoid duplicates and check for half-step intervals
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if prev_chord != chord_value:
|
||||
events_list.append(self._create_event('Chord', chord_value))
|
||||
prev_chord = chord_value
|
||||
if prev_tempo != tempo_value:
|
||||
events_list.append(self._create_event('Tempo', tempo_value))
|
||||
prev_tempo = tempo_value
|
||||
|
||||
# add instrument and note
|
||||
for instrument in pruned_instrument_note_dict:
|
||||
if self.num_features in {5, 8}:
|
||||
events_list.append(self._create_event('Instrument', instrument))
|
||||
|
||||
for pitch in pruned_instrument_note_dict[instrument]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument][pitch], key=lambda x: x[0])
|
||||
note_event = [
|
||||
self._create_event('Note_Pitch', pitch),
|
||||
self._create_event('Note_Duration', max_duration[0])
|
||||
]
|
||||
if self.num_features in {7, 8}:
|
||||
note_event.append(self._create_event('Note_Velocity', max_duration[1]))
|
||||
events_list.extend(note_event)
|
||||
|
||||
# If there are events in this step, add a "Beat" event and the collected events
|
||||
if len(events_list):
|
||||
final_sequence.append(self._create_event('Beat', in_beat_off_idx))
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self._create_event('Bar', None)) # Add final bar token
|
||||
final_sequence.append(self._create_event('EOS', None)) # Add End of Sequence (EOS) token
|
||||
return final_sequence
|
||||
|
||||
################################# for CP style encoding #################################
|
||||
|
||||
class Corpus2event_cp(Corpus2event_remi):
|
||||
def __init__(self, num_features):
|
||||
super().__init__(num_features)
|
||||
self.num_features = num_features
|
||||
self._init_event_template()
|
||||
|
||||
def _init_event_template(self):
|
||||
'''
|
||||
The order of musical features is Type, Beat, Chord, Tempo, Instrument, Pitch, Duration, Velocity
|
||||
'''
|
||||
self.event_template = {}
|
||||
if self.num_features == 8:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 7:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 5:
|
||||
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
|
||||
elif self.num_features == 4:
|
||||
feature_names = ['type', 'beat', 'pitch', 'duration']
|
||||
for feature_name in feature_names:
|
||||
self.event_template[feature_name] = 0
|
||||
|
||||
def create_cp_sos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'SOS'
|
||||
return total_event
|
||||
|
||||
def create_cp_eos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'EOS'
|
||||
return total_event
|
||||
|
||||
def create_cp_metrical_event(self, pos, chord, tempo):
|
||||
'''
|
||||
when the compound token is related to metrical information
|
||||
'''
|
||||
meter_event = self.event_template.copy()
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = pos
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
meter_event['chord'] = chord
|
||||
meter_event['tempo'] = tempo
|
||||
return meter_event
|
||||
|
||||
def create_cp_note_event(self, instrument_name, pitch, duration, velocity):
|
||||
'''
|
||||
when the compound token is related to note information
|
||||
'''
|
||||
note_event = self.event_template.copy()
|
||||
note_event['type'] = 'Note'
|
||||
note_event['pitch'] = pitch
|
||||
note_event['duration'] = duration
|
||||
if self.num_features == 5 or self.num_features == 8:
|
||||
note_event['instrument'] = instrument_name
|
||||
if self.num_features == 7 or self.num_features == 8:
|
||||
note_event['velocity'] = velocity
|
||||
return note_event
|
||||
|
||||
def create_cp_bar_event(self, time_sig_change_flag=False, time_sig_name=None):
|
||||
meter_event = self.event_template.copy()
|
||||
if time_sig_change_flag:
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = f'Bar_{time_sig_name}'
|
||||
else:
|
||||
meter_event['type'] = 'Metrical'
|
||||
meter_event['beat'] = 'Bar'
|
||||
return meter_event
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution):
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
|
||||
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Process time signature changes and adjust them for the given song structure
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if no valid time signature changes found
|
||||
|
||||
# --- create sequence --- #
|
||||
final_sequence = [] # Initialize the final sequence to store the events
|
||||
final_sequence.append(self.create_cp_sos_event()) # Add the Start-of-Sequence (SOS) event
|
||||
chord_text = None # Placeholder for the current chord
|
||||
tempo_text = None # Placeholder for the current tempo
|
||||
|
||||
# Loop through each time signature change and process the corresponding measures
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to track when time signature changes
|
||||
# Calculate bar resolution (number of ticks per bar based on the time signature)
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
|
||||
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
|
||||
|
||||
# Determine the point for the next time signature change or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx + 1].time
|
||||
|
||||
# Iterate over each measure (bar) between the current and next time signature change
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
empty_bar_token = self.create_cp_bar_event() # Create an empty bar event
|
||||
|
||||
# Check if the last four events in the sequence are consecutive empty bars
|
||||
if len(final_sequence) >= 4:
|
||||
if not (final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token):
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
else:
|
||||
if time_sig_change_flag:
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
else:
|
||||
final_sequence.append(self.create_cp_bar_event(time_sig_change_flag, time_sig_name))
|
||||
|
||||
# Reset the time signature change flag after handling the bar event
|
||||
time_sig_change_flag = False
|
||||
|
||||
# Loop through beats in each measure based on the in-beat resolution
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
chord_tempo_flag = False # Flag to track if chord and tempo events are added
|
||||
events_list = [] # List to hold events for the current beat
|
||||
pos_text = 'Beat_' + str(in_beat_off_idx) # Create a beat event label
|
||||
|
||||
# --- chord & tempo processing --- #
|
||||
# Unpack chords and tempos for the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# If a chord is present, extract its root, quality, and bass
|
||||
if self.num_features in {7, 8}:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_')
|
||||
chord_text = 'Chord_' + root + '_' + quality
|
||||
|
||||
# If a tempo is present, format it as a string
|
||||
if t_tempos is not None:
|
||||
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# --- instrument & note processing --- #
|
||||
# Loop through each instrument and process its notes at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If notes are present, process them
|
||||
if t_notes != None:
|
||||
# Track notes and their properties (duration and velocity) for the current instrument
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if not chord_tempo_flag:
|
||||
if chord_text == None:
|
||||
chord_text = 'Chord_N_N'
|
||||
if tempo_text == None:
|
||||
tempo_text = 'Tempo_N_N'
|
||||
chord_tempo_flag = True
|
||||
|
||||
events_list.append(self.create_cp_metrical_event(pos_text, chord_text, tempo_text))
|
||||
|
||||
# add instrument and note
|
||||
for instrument_idx in pruned_instrument_note_dict:
|
||||
instrument_name = 'Instrument_' + str(instrument_idx)
|
||||
for pitch in pruned_instrument_note_dict[instrument_idx]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
|
||||
note_pitch_text = 'Note_Pitch_' + str(pitch)
|
||||
note_duration_text = 'Note_Duration_' + str(max_duration[0])
|
||||
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
|
||||
events_list.append(self.create_cp_note_event(instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
|
||||
|
||||
# If there are any events for this beat, add them to the final sequence
|
||||
if len(events_list) > 0:
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self.create_cp_bar_event()) # Add the final bar event
|
||||
final_sequence.append(self.create_cp_eos_event()) # Add the End-of-Sequence (EOS) event
|
||||
return final_sequence # Return the final sequence of events
|
||||
|
||||
################################# for NB style encoding #################################
|
||||
|
||||
class Corpus2event_nb(Corpus2event_cp):
|
||||
def __init__(self, num_features):
|
||||
'''
|
||||
For convenience in logging, we use "type" word for "metric" sub-token in the code to compare easily with other encoding schemes
|
||||
'''
|
||||
super().__init__(num_features)
|
||||
self.num_features = num_features
|
||||
self._init_event_template()
|
||||
|
||||
def _init_event_template(self):
|
||||
self.event_template = {}
|
||||
if self.num_features == 8:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 7:
|
||||
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
|
||||
elif self.num_features == 5:
|
||||
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
|
||||
elif self.num_features == 4:
|
||||
feature_names = ['type', 'beat', 'pitch', 'duration']
|
||||
for feature_name in feature_names:
|
||||
self.event_template[feature_name] = 0
|
||||
|
||||
def create_nb_sos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'SOS'
|
||||
return total_event
|
||||
|
||||
def create_nb_eos_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'EOS'
|
||||
return total_event
|
||||
|
||||
def create_nb_event(self, bar_beat_type, pos, chord, tempo, instrument_name, pitch, duration, velocity):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = bar_beat_type
|
||||
total_event['beat'] = pos
|
||||
total_event['pitch'] = pitch
|
||||
total_event['duration'] = duration
|
||||
if self.num_features in {5, 8}:
|
||||
total_event['instrument'] = instrument_name
|
||||
if self.num_features in {7, 8}:
|
||||
total_event['chord'] = chord
|
||||
total_event['tempo'] = tempo
|
||||
total_event['velocity'] = velocity
|
||||
return total_event
|
||||
|
||||
def create_nb_empty_bar_event(self):
|
||||
total_event = self.event_template.copy()
|
||||
total_event['type'] = 'Empty_Bar'
|
||||
return total_event
|
||||
|
||||
def get_bar_beat_idx(self, bar_flag, beat_flag, time_sig_name, time_sig_change_flag):
|
||||
'''
|
||||
This function is to get the metric information for the current bar and beat
|
||||
There are four types of metric information: NNN, SNN, SSN, SSS
|
||||
Each letter represents the change of time signature, bar, and beat (new or same)
|
||||
'''
|
||||
if time_sig_change_flag: # new time signature
|
||||
return "NNN_" + time_sig_name
|
||||
else:
|
||||
if bar_flag and beat_flag: # same time sig & new bar & new beat
|
||||
return "SNN"
|
||||
elif not bar_flag and beat_flag: # same time sig & same bar & new beat
|
||||
return "SSN"
|
||||
elif not bar_flag and not beat_flag: # same time sig & same bar & same beat
|
||||
return "SSS"
|
||||
|
||||
def __call__(self, song_data, in_beat_resolution:int):
|
||||
# --- global tag --- #
|
||||
first_note_tick = song_data['metadata']['first_note'] # First note timestamp in ticks
|
||||
global_end = song_data['metadata']['last_note'] # Last note timestamp in ticks
|
||||
time_signature_changes = song_data['metadata']['time_signature'] # Time signature changes throughout the song
|
||||
ticks_per_beat = song_data['metadata']['ticks_per_beat'] # Ticks per beat (resolution of the timing grid)
|
||||
in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution) # Tick resolution for beats
|
||||
instrument_list = sorted(list(song_data['notes'].keys())) # List of instruments in the song
|
||||
|
||||
# --- process time signature --- #
|
||||
# Process time signature changes and adjust them for the given song structure
|
||||
time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
|
||||
if time_signature_changes == None:
|
||||
return None # Exit if no valid time signature changes found
|
||||
|
||||
# --- create sequence --- #
|
||||
final_sequence = [] # Initialize the final sequence to store the events
|
||||
final_sequence.append(self.create_nb_sos_event()) # Add the Start-of-Sequence (SOS) event
|
||||
chord_text = None # Placeholder for the current chord
|
||||
tempo_text = None # Placeholder for the current tempo
|
||||
|
||||
# Loop through each time signature change and process the corresponding measures
|
||||
for idx in range(len(time_signature_changes)):
|
||||
time_sig_change_flag = True # Flag to track when time signature changes
|
||||
# Calculate bar resolution (number of ticks per bar based on the time signature)
|
||||
numerator = time_signature_changes[idx].numerator
|
||||
denominator = time_signature_changes[idx].denominator
|
||||
time_sig_name = f'time_signature_{numerator}/{denominator}' # Format the time signature as a string
|
||||
bar_resol = int(ticks_per_beat * numerator * (4 / denominator)) # Calculate number of ticks per bar
|
||||
bar_start_tick = time_signature_changes[idx].time # Starting tick for this time signature
|
||||
|
||||
# Determine the point for the next time signature change or the end of the song
|
||||
if idx == len(time_signature_changes) - 1:
|
||||
next_change_point = global_end
|
||||
else:
|
||||
next_change_point = time_signature_changes[idx + 1].time
|
||||
|
||||
# Iterate over each measure (bar) between the current and next time signature change
|
||||
for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
|
||||
bar_flag = True
|
||||
note_flag = False
|
||||
|
||||
# Loop through beats in each measure based on the in-beat resolution
|
||||
for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
|
||||
beat_flag = True
|
||||
events_list = []
|
||||
pos_text = 'Beat_' + str(in_beat_off_idx)
|
||||
|
||||
# --- chord & tempo processing --- #
|
||||
# Unpack chords and tempos for the current beat step
|
||||
t_chords = song_data['chords'].get(beat_step)
|
||||
t_tempos = song_data['tempos'].get(beat_step)
|
||||
|
||||
# If a chord is present, extract its root, quality, and bass
|
||||
if self.num_features == 8 or self.num_features == 7:
|
||||
if t_chords is not None:
|
||||
root, quality, _ = t_chords[-1].text.split('_')
|
||||
chord_text = 'Chord_' + root + '_' + quality
|
||||
|
||||
# If a tempo is present, format it as a string
|
||||
if t_tempos is not None:
|
||||
tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)
|
||||
|
||||
# Dictionary to track notes for each instrument to avoid duplicates
|
||||
instrument_note_dict = defaultdict(dict)
|
||||
|
||||
# --- instrument & note processing --- #
|
||||
# Loop through each instrument and process its notes at the current beat step
|
||||
for instrument_idx in instrument_list:
|
||||
t_notes = song_data['notes'][instrument_idx].get(beat_step)
|
||||
|
||||
# If notes are present, process them
|
||||
if t_notes != None:
|
||||
note_flag = True
|
||||
|
||||
# Track notes and their properties (duration and velocity) for the current instrument
|
||||
for note in t_notes:
|
||||
if note.pitch not in instrument_note_dict[instrument_idx]:
|
||||
instrument_note_dict[instrument_idx][note.pitch] = [(note.quantized_duration, note.velocity)]
|
||||
else:
|
||||
instrument_note_dict[instrument_idx][note.pitch].append((note.quantized_duration, note.velocity))
|
||||
|
||||
# # Check for half-step interval gaps and handle them accordingly
|
||||
# self._half_step_interval_gap_check(instrument_note_dict, instrument_idx)
|
||||
|
||||
if len(instrument_note_dict) == 0:
|
||||
continue
|
||||
|
||||
# Check for half-step interval gaps and handle them across instruments
|
||||
pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)
|
||||
|
||||
# add chord and tempo
|
||||
if self.num_features in {7, 8}:
|
||||
if chord_text == None:
|
||||
chord_text = 'Chord_N_N'
|
||||
if tempo_text == None:
|
||||
tempo_text = 'Tempo_N_N'
|
||||
|
||||
# add instrument and note
|
||||
for instrument_idx in pruned_instrument_note_dict:
|
||||
instrument_name = 'Instrument_' + str(instrument_idx)
|
||||
for pitch in pruned_instrument_note_dict[instrument_idx]:
|
||||
max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
|
||||
note_pitch_text = 'Note_Pitch_' + str(pitch)
|
||||
note_duration_text = 'Note_Duration_' + str(max_duration[0])
|
||||
note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
|
||||
bar_beat_type = self.get_bar_beat_idx(bar_flag, beat_flag, time_sig_name, time_sig_change_flag)
|
||||
events_list.append(self.create_nb_event(bar_beat_type, pos_text, chord_text, tempo_text, instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
|
||||
bar_flag = False
|
||||
beat_flag = False
|
||||
time_sig_change_flag = False
|
||||
|
||||
# If there are any events for this beat, add them to the final sequence
|
||||
if events_list != None and len(events_list):
|
||||
final_sequence.extend(events_list)
|
||||
|
||||
# when there is no note in this bar
|
||||
if not note_flag:
|
||||
# avoid consecutive empty bars (more than 4 is not allowed)
|
||||
empty_bar_token = self.create_nb_empty_bar_event()
|
||||
if len(final_sequence) >= 4:
|
||||
if final_sequence[-1] == empty_bar_token and final_sequence[-2] == empty_bar_token and final_sequence[-3] == empty_bar_token and final_sequence[-4] == empty_bar_token:
|
||||
continue
|
||||
final_sequence.append(empty_bar_token)
|
||||
|
||||
# --- end with BAR & EOS --- #
|
||||
final_sequence.append(self.create_nb_eos_event())
|
||||
return final_sequence
|
||||
Reference in New Issue
Block a user