Files
2025-09-08 14:49:28 +08:00

879 lines
46 KiB
Python

from typing import Any
from fractions import Fraction
from collections import defaultdict
from miditoolkit import TimeSignature
from constants import *
'''
This script contains specific encoding functions for different encoding schemes.
'''
def frange(start, stop, step):
    """Yield ``start``, ``start + step``, ... while the value is below ``stop``.

    A ``range``-like generator that accepts any numeric type supporting ``<``
    and ``+`` (the encoders in this file pass ``Fraction`` steps).

    Args:
        start: First value to yield.
        stop: Exclusive upper bound.
        step: Increment added after every yielded value.

    Yields:
        Successive values ``start + k * step`` that are strictly below ``stop``.

    Raises:
        ValueError: If ``step`` is not positive although ``start < stop``;
            such a call could never reach ``stop`` and would loop forever.
    """
    # Only reject the combination that cannot terminate; an empty range
    # (start >= stop) still yields nothing whatever the step is.
    if start < stop and step <= 0:
        raise ValueError(f"frange() step must be positive, got {step!r}")
    current = start
    while current < stop:
        yield current
        current += step
################################# for REMI style encoding #################################
class Corpus2event_remi():
    """Encode a pre-processed song as a flat, REMI-style list of events.

    Every event is a ``{'name': ..., 'value': ...}`` dictionary.
    ``num_features`` selects the optional events: 7 or 8 adds Chord, Tempo and
    Note_Velocity events, 5 or 8 adds Instrument events.
    """

    def __init__(self, num_features: int):
        self.num_features = num_features

    def _create_event(self, name, value):
        """Return one event dictionary holding ``name`` and ``value``."""
        return {'name': name, 'value': value}

    @staticmethod
    def _bar_ticks(ticks_per_beat, numerator, denominator):
        """Return the length in ticks of one bar of ``numerator/denominator``."""
        return int(ticks_per_beat * numerator * (4 / denominator))

    def _break_down_numerator(self, numerator, possible_time_signatures):
        """Greedily break a numerator down into smaller time signatures.

        The candidates are tried from the largest numerator to the smallest and
        each one is used as often as it fits into what is left.

        Args:
            numerator: Target numerator to decompose (must be > 0).
            possible_time_signatures: List of (numerator, denominator) tuples.
                It is sorted by descending numerator here, so the caller's
                order does not matter.

        Returns:
            List of the signatures used, largest first, e.g. numerator 7 with
            candidates [(4, 4), (3, 4)] gives [(4, 4), (3, 4)].

        Raises:
            ValueError: If ``numerator`` is not positive, no candidate is
                given, or the greedy decomposition leaves a remainder.
        """
        if numerator <= 0:
            raise ValueError("Numerator must be positive.")
        if not possible_time_signatures:
            raise ValueError("No possible time signatures provided.")

        original_numerator = numerator  # kept for the error message
        possible_time_signatures = sorted(possible_time_signatures, key=lambda x: -x[0])

        result = []
        for sig in possible_time_signatures:
            sig_numerator, _ = sig
            if sig_numerator <= 0:
                continue  # an invalid candidate can never make progress
            while numerator >= sig_numerator:
                result.append(sig)
                numerator -= sig_numerator

        # After one descending pass the remainder is smaller than every valid
        # candidate, so another pass could not reduce it any further.
        if numerator > 0:
            raise ValueError(
                f"Cannot decompose numerator {original_numerator} "
                f"with given time signatures {possible_time_signatures}. "
                f"Remaining: {numerator}"
            )
        return result

    def _normalize_time_signature(self, time_signature, ticks_per_beat, next_change_point):
        """Replace an irregular time signature by a run of regular ones.

        Parameters:
        - time_signature: object with ``numerator``, ``denominator`` and ``time``
          (start tick) attributes.
        - ticks_per_beat: number of ticks per beat.
        - next_change_point: tick at which the next time signature (or the end
          of the piece) begins.

        Returns:
        - A list of time signatures. A signature found in REGULAR_NUM_DENOM is
          returned unchanged as a one-element list; otherwise new TimeSignature
          objects are created.

        Procedure:
        1. Signatures listed in REGULAR_NUM_DENOM are returned as they are.
        2. Denominators 16, 32 and 64 are mapped to a single 4/4.
        3. 6/4 is mapped to 3/4 bars.
        4. Any other signature is decomposed greedily into the CORE_NUM_DENOM
           signatures that share its denominator. The resulting bars are cut
           once they run past ``next_change_point``; if they stop short of it,
           the first (largest) bar is repeated until the span is covered.
        5. Every bar gets its absolute start tick.

        Raises:
        - ValueError: if the numerator cannot be decomposed or a bar of zero
          length would have to be repeated.
        """
        numerator = time_signature.numerator
        denominator = time_signature.denominator
        bar_start_tick = time_signature.time

        if (numerator, denominator) in REGULAR_NUM_DENOM:
            return [time_signature]

        if denominator in (16, 32, 64):
            return [TimeSignature(4, 4, bar_start_tick)]

        # 6/4 means numerator 6 over denominator 4 (the previous test had the
        # two swapped and therefore never matched).
        # NOTE: both entries share one start tick, as this branch always
        # returned; the encoders skip the empty first span and render the
        # whole region as 3/4 bars.
        if numerator == 6 and denominator == 4:
            return [TimeSignature(3, 4, bar_start_tick), TimeSignature(3, 4, bar_start_tick)]

        # Candidate signatures with the same denominator, largest numerator first
        possible_time_signatures = [sig for sig in CORE_NUM_DENOM if sig[1] == denominator]
        possible_time_signatures.sort(key=lambda x: x[0], reverse=True)
        try:
            result = self._break_down_numerator(numerator, possible_time_signatures)
        except ValueError as err:
            raise ValueError(
                f"Failed to normalize time signature {numerator}/{denominator}. "
                f"Could not break down numerator {numerator} with available signatures: "
                f"{possible_time_signatures}"
            ) from err

        available_ticks = next_change_point - bar_start_tick

        # Drop the bars that start after the available span is already exceeded
        total_length = 0
        for idx, sig in enumerate(result):
            total_length += self._bar_ticks(ticks_per_beat, sig[0], sig[1])
            if total_length > available_ticks:
                result = result[:idx + 1]
                break

        # If the bars stop short of the next change, repeat the largest bar
        first_bar_ticks = self._bar_ticks(ticks_per_beat, result[0][0], result[0][1])
        if total_length < available_ticks and first_bar_ticks <= 0:
            raise ValueError(
                f"Cannot fill {available_ticks} ticks with bars of length {first_bar_ticks} "
                f"(ticks_per_beat={ticks_per_beat})"
            )
        while total_length < available_ticks:
            result.append(result[0])
            total_length += first_bar_ticks

        # Give every bar its absolute start tick
        normalized_result = []
        tick = bar_start_tick
        for sig_numerator, sig_denominator in result:
            normalized_result.append(TimeSignature(sig_numerator, sig_denominator, tick))
            tick += self._bar_ticks(ticks_per_beat, sig_numerator, sig_denominator)
        return normalized_result

    def _process_time_signature(self, time_signature_changes, ticks_per_beat, first_note_tick, global_end):
        """Validate, de-duplicate and normalize the time signature changes.

        Parameters:
        - time_signature_changes: list of time signature objects with
          ``numerator``, ``denominator`` and ``time`` attributes.
        - ticks_per_beat: number of ticks per beat.
        - first_note_tick: tick of the first note of the piece.
        - global_end: tick at which the piece ends.

        Returns:
        - The normalized list of time signatures; ``[TimeSignature(4, 4, 0)]``
          when the input is empty; ``None`` when the first change is neither at
          tick 0 nor at/before the first note.
        """
        if len(time_signature_changes) == 0:
            print("No time signature change in this tune, default to 4/4 time signature")
            return [TimeSignature(4, 4, 0)]

        first_change_tick = time_signature_changes[0].time
        if first_change_tick != 0 and first_change_tick > first_note_tick:
            print("The first time signature change is not at the beginning of the tune")
            return None

        # Drop a change that repeats the signature right before it
        processed_time_signature_changes = []
        for time_sig in time_signature_changes:
            if processed_time_signature_changes:
                kept = processed_time_signature_changes[-1]
                if kept.numerator == time_sig.numerator and kept.denominator == time_sig.denominator:
                    continue
            processed_time_signature_changes.append(time_sig)

        # The span of every signature must be read from the de-duplicated list:
        # indexing the original list with these positions picks the wrong next
        # change as soon as one duplicate has been removed.
        normalized_time_signature_changes = []
        last_idx = len(processed_time_signature_changes) - 1
        for idx, time_signature in enumerate(processed_time_signature_changes):
            if idx == last_idx:
                next_change_point = global_end
            else:
                next_change_point = processed_time_signature_changes[idx + 1].time
            normalized_time_signature_changes.extend(
                self._normalize_time_signature(time_signature, ticks_per_beat, next_change_point)
            )
        return normalized_time_signature_changes

    def _half_step_interval_gap_check_across_instruments(self, instrument_note_dict):
        '''
        Thin out notes that are a half step apart, looking at all instruments together.

        Parameters:
        - instrument_note_dict: {instrument: {pitch: [(duration, velocity), ...]}}

        Returns:
        - A new dict with the instruments in ascending key order and, for each
          instrument, the surviving pitches in descending order. The input
          dict is not modified. Instruments that lose all their pitches are
          kept with an empty dict.

        The distinct pitches are walked from high to low in adjacent pairs. When
        a pair is one semitone apart, only the pitch with the longer maximum
        duration survives (the upper pitch wins a tie) and the walk continues
        after the pair. In a chromatic run of three or more pitches, neighbours
        taken from different pairs can therefore both survive.
        '''
        # Sort by instrument key so the output order is deterministic
        instrument_note_dict = dict(sorted(instrument_note_dict.items()))

        # pitch -> [(instrument, [(duration, velocity), ...]), ...]
        all_pitches = defaultdict(list)
        for instrument, notes in instrument_note_dict.items():
            for pitch, durations in notes.items():
                all_pitches[pitch].append((instrument, durations))

        sorted_pitches = sorted(all_pitches, reverse=True)

        def longest_duration(pitch):
            return max(duration for _, durations in all_pitches[pitch] for duration, _ in durations)

        kept_pitches = set()
        last_idx = len(sorted_pitches) - 1
        idx = 0
        while idx < last_idx:
            upper_pitch = sorted_pitches[idx]
            lower_pitch = sorted_pitches[idx + 1]
            if upper_pitch - lower_pitch == 1:
                if longest_duration(upper_pitch) < longest_duration(lower_pitch):
                    kept_pitches.add(lower_pitch)
                else:
                    kept_pitches.add(upper_pitch)
                idx += 2  # both pitches of the pair have been decided
            else:
                kept_pitches.add(upper_pitch)
                idx += 1
        # The lowest pitch is still pending when it was not consumed by a pair
        if idx == last_idx:
            kept_pitches.add(sorted_pitches[-1])

        return {
            instrument: {
                pitch: notes[pitch]
                for pitch in sorted(notes, reverse=True) if pitch in kept_pitches
            }
            for instrument, notes in instrument_note_dict.items()
        }

    def __call__(self, song_data, in_beat_resolution):
        '''
        Turn a song into a sequence of REMI events.

        Parameters:
        - song_data: dict with the keys 'metadata' (with 'first_note', 'last_note',
          'time_signature', 'ticks_per_beat'), 'notes' ({instrument: {tick: [note, ...]}}),
          'chords' and 'tempos' ({tick: [...]}). Notes are read through their
          ``pitch``, ``quantized_duration`` and ``velocity`` attributes.
        - in_beat_resolution: number of grid steps per beat.

        Returns:
        - The event list: SOS, then for every bar a Bar event followed by the
          Beat / Chord / Tempo / Instrument / Note events of the steps that
          contain notes, and finally a closing Bar and EOS. Returns None when
          the time signatures are rejected by `_process_time_signature`.

        Notes:
        - The first bar of every time signature section carries the signature
          name as its value; the other bars carry None.
        - A plain Bar is not added when the last four events are already plain
          Bar events, which caps runs of empty bars.
        - Chord and Tempo events are only emitted when their value differs from
          the previously emitted one, and only on steps that contain notes.
        - For each pitch the (duration, velocity) pair with the longest
          duration is encoded.
        '''
        # --- global tag --- #
        metadata = song_data['metadata']
        first_note_tick = metadata['first_note']
        global_end = metadata['last_note']
        ticks_per_beat = metadata['ticks_per_beat']
        # Grid step in ticks, kept exact as a fraction. Example: 1024/12 -> (256, 3)
        in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution)
        instrument_list = sorted(song_data['notes'].keys())

        # --- process time signature --- #
        time_signature_changes = self._process_time_signature(
            metadata['time_signature'], ticks_per_beat, first_note_tick, global_end
        )
        if time_signature_changes is None:
            return None

        with_chord_tempo = self.num_features in {7, 8}
        with_instrument = self.num_features in {5, 8}

        # --- create sequence --- #
        final_sequence = [self._create_event('SOS', None)]
        empty_bar_token = self._create_event('Bar', None)  # used for comparison only
        prev_chord = None
        prev_tempo = None
        chord_value = None
        tempo_value = None

        last_section_idx = len(time_signature_changes) - 1
        for idx, time_signature in enumerate(time_signature_changes):
            numerator = time_signature.numerator
            denominator = time_signature.denominator
            time_sig_name = f'time_signature_{numerator}/{denominator}'
            bar_resol = int(ticks_per_beat * numerator * (4 / denominator))
            bar_start_tick = time_signature.time
            if idx == last_section_idx:
                next_change_point = global_end
            else:
                next_change_point = time_signature_changes[idx + 1].time

            time_sig_change_flag = True  # True only for the first bar of the section
            for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
                if time_sig_change_flag:
                    final_sequence.append(self._create_event('Bar', time_sig_name))
                elif len(final_sequence) < 4 or final_sequence[-4:] != [empty_bar_token] * 4:
                    final_sequence.append(self._create_event('Bar', None))
                time_sig_change_flag = False

                beat_steps = frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)
                for in_beat_off_idx, beat_step in enumerate(beat_steps):
                    t_chords = song_data['chords'].get(beat_step)
                    t_tempos = song_data['tempos'].get(beat_step)
                    if with_chord_tempo:
                        if t_chords is not None:
                            root, quality, _ = t_chords[-1].text.split('_')
                            chord_value = root + '_' + quality
                        if t_tempos is not None:
                            tempo_value = t_tempos[-1].tempo

                    # instrument -> pitch -> [(duration, velocity), ...]
                    instrument_note_dict = defaultdict(dict)
                    for instrument_idx in instrument_list:
                        t_notes = song_data['notes'][instrument_idx].get(beat_step)
                        if t_notes is None:
                            continue
                        for note in t_notes:
                            instrument_note_dict[instrument_idx].setdefault(note.pitch, []).append(
                                (note.quantized_duration, note.velocity)
                            )
                    if not instrument_note_dict:
                        continue  # nothing starts on this step

                    pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(
                        instrument_note_dict
                    )

                    events_list = []
                    if with_chord_tempo:
                        if prev_chord != chord_value:
                            events_list.append(self._create_event('Chord', chord_value))
                            prev_chord = chord_value
                        if prev_tempo != tempo_value:
                            events_list.append(self._create_event('Tempo', tempo_value))
                            prev_tempo = tempo_value

                    for instrument, pitches in pruned_instrument_note_dict.items():
                        if with_instrument:
                            events_list.append(self._create_event('Instrument', instrument))
                        for pitch, candidates in pitches.items():
                            duration, velocity = max(candidates, key=lambda x: x[0])
                            events_list.append(self._create_event('Note_Pitch', pitch))
                            events_list.append(self._create_event('Note_Duration', duration))
                            if with_chord_tempo:
                                events_list.append(self._create_event('Note_Velocity', velocity))

                    if events_list:
                        final_sequence.append(self._create_event('Beat', in_beat_off_idx))
                        final_sequence.extend(events_list)

        # --- end with BAR & EOS --- #
        final_sequence.append(self._create_event('Bar', None))
        final_sequence.append(self._create_event('EOS', None))
        return final_sequence
################################# for CP style encoding #################################
class Corpus2event_cp(Corpus2event_remi):
    """Encode a pre-processed song as compound (CP-style) tokens.

    Every token is a dictionary with one entry per musical feature. Entries a
    token does not use keep the template value 0.
    """

    # Sub-token names for every supported feature count.
    # The order of musical features is Type, Beat, Chord, Tempo, Instrument,
    # Pitch, Duration, Velocity.
    _CP_FEATURE_NAMES = {
        8: ('type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity'),
        7: ('type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity'),
        5: ('type', 'beat', 'instrument', 'pitch', 'duration'),
        4: ('type', 'beat', 'pitch', 'duration'),
    }

    def __init__(self, num_features):
        # The parent constructor stores num_features
        super().__init__(num_features)
        self._init_event_template()

    def _init_event_template(self):
        '''
        Build ``self.event_template``: a dict holding 0 for every sub-token of
        the configured feature count.

        Raises:
        - ValueError: if ``self.num_features`` is not 4, 5, 7 or 8.
        '''
        try:
            feature_names = self._CP_FEATURE_NAMES[self.num_features]
        except KeyError:
            raise ValueError(
                f"Unsupported num_features {self.num_features!r}; expected one of 4, 5, 7, 8"
            ) from None
        self.event_template = {feature_name: 0 for feature_name in feature_names}

    def create_cp_sos_event(self):
        """Return the start-of-sequence token (template with type 'SOS')."""
        total_event = self.event_template.copy()
        total_event['type'] = 'SOS'
        return total_event

    def create_cp_eos_event(self):
        """Return the end-of-sequence token (template with type 'EOS')."""
        total_event = self.event_template.copy()
        total_event['type'] = 'EOS'
        return total_event

    def create_cp_metrical_event(self, pos, chord, tempo):
        '''
        Return a 'Metrical' token for position ``pos``.

        ``chord`` and ``tempo`` are only stored for 7 or 8 features.
        '''
        meter_event = self.event_template.copy()
        meter_event['type'] = 'Metrical'
        meter_event['beat'] = pos
        if self.num_features in {7, 8}:
            meter_event['chord'] = chord
            meter_event['tempo'] = tempo
        return meter_event

    def create_cp_note_event(self, instrument_name, pitch, duration, velocity):
        '''
        Return a 'Note' token.

        ``instrument_name`` is only stored for 5 or 8 features and ``velocity``
        only for 7 or 8 features.
        '''
        note_event = self.event_template.copy()
        note_event['type'] = 'Note'
        note_event['pitch'] = pitch
        note_event['duration'] = duration
        if self.num_features in {5, 8}:
            note_event['instrument'] = instrument_name
        if self.num_features in {7, 8}:
            note_event['velocity'] = velocity
        return note_event

    def create_cp_bar_event(self, time_sig_change_flag=False, time_sig_name=None):
        """Return a 'Metrical' bar token.

        The beat entry is ``'Bar_<time_sig_name>'`` when ``time_sig_change_flag``
        is set and plain ``'Bar'`` otherwise.
        """
        meter_event = self.event_template.copy()
        meter_event['type'] = 'Metrical'
        meter_event['beat'] = f'Bar_{time_sig_name}' if time_sig_change_flag else 'Bar'
        return meter_event

    def __call__(self, song_data, in_beat_resolution):
        '''
        Turn a song into a sequence of compound tokens.

        Parameters:
        - song_data: dict with the keys 'metadata' (with 'first_note', 'last_note',
          'time_signature', 'ticks_per_beat'), 'notes' ({instrument: {tick: [note, ...]}}),
          'chords' and 'tempos' ({tick: [...]}).
        - in_beat_resolution: number of grid steps per beat.

        Returns:
        - The token list: SOS, then for every bar a bar token followed, for each
          step that contains notes, by one metrical token and the note tokens,
          and finally a closing bar token and EOS. Returns None when the time
          signatures are rejected by `_process_time_signature`.

        Notes:
        - A plain bar token is not added when the last four tokens are already
          plain bar tokens.
        - The metrical token carries the most recent chord and tempo seen so
          far, or 'Chord_N_N' / 'Tempo_N_N' while none has been seen.
        '''
        # --- global tag --- #
        metadata = song_data['metadata']
        first_note_tick = metadata['first_note']
        global_end = metadata['last_note']
        ticks_per_beat = metadata['ticks_per_beat']
        in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution)
        instrument_list = sorted(song_data['notes'].keys())

        # --- process time signature --- #
        time_signature_changes = self._process_time_signature(
            metadata['time_signature'], ticks_per_beat, first_note_tick, global_end
        )
        if time_signature_changes is None:
            return None

        with_chord_tempo = self.num_features in {7, 8}

        # --- create sequence --- #
        final_sequence = [self.create_cp_sos_event()]
        empty_bar_token = self.create_cp_bar_event()  # used for comparison only
        chord_text = None
        tempo_text = None

        last_section_idx = len(time_signature_changes) - 1
        for idx, time_signature in enumerate(time_signature_changes):
            numerator = time_signature.numerator
            denominator = time_signature.denominator
            time_sig_name = f'time_signature_{numerator}/{denominator}'
            bar_resol = int(ticks_per_beat * numerator * (4 / denominator))
            bar_start_tick = time_signature.time
            if idx == last_section_idx:
                next_change_point = global_end
            else:
                next_change_point = time_signature_changes[idx + 1].time

            time_sig_change_flag = True  # True only for the first bar of the section
            for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
                if time_sig_change_flag:
                    final_sequence.append(self.create_cp_bar_event(True, time_sig_name))
                elif len(final_sequence) < 4 or final_sequence[-4:] != [empty_bar_token] * 4:
                    final_sequence.append(self.create_cp_bar_event())
                time_sig_change_flag = False

                beat_steps = frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)
                for in_beat_off_idx, beat_step in enumerate(beat_steps):
                    pos_text = 'Beat_' + str(in_beat_off_idx)

                    # --- chord & tempo processing --- #
                    t_chords = song_data['chords'].get(beat_step)
                    t_tempos = song_data['tempos'].get(beat_step)
                    if with_chord_tempo:
                        if t_chords is not None:
                            root, quality, _ = t_chords[-1].text.split('_')
                            chord_text = 'Chord_' + root + '_' + quality
                        if t_tempos is not None:
                            tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)

                    # --- instrument & note processing --- #
                    # instrument -> pitch -> [(duration, velocity), ...]
                    instrument_note_dict = defaultdict(dict)
                    for instrument_idx in instrument_list:
                        t_notes = song_data['notes'][instrument_idx].get(beat_step)
                        if t_notes is None:
                            continue
                        for note in t_notes:
                            instrument_note_dict[instrument_idx].setdefault(note.pitch, []).append(
                                (note.quantized_duration, note.velocity)
                            )
                    if not instrument_note_dict:
                        continue  # nothing starts on this step

                    pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(
                        instrument_note_dict
                    )

                    if with_chord_tempo:
                        if chord_text is None:
                            chord_text = 'Chord_N_N'
                        if tempo_text is None:
                            tempo_text = 'Tempo_N_N'
                    # The position token is needed for every feature count;
                    # create_cp_metrical_event ignores chord and tempo itself
                    # when they are not part of the template.
                    events_list = [self.create_cp_metrical_event(pos_text, chord_text, tempo_text)]

                    for instrument_idx, pitches in pruned_instrument_note_dict.items():
                        instrument_name = 'Instrument_' + str(instrument_idx)
                        for pitch, candidates in pitches.items():
                            duration, velocity = max(candidates, key=lambda x: x[0])
                            events_list.append(self.create_cp_note_event(
                                instrument_name,
                                'Note_Pitch_' + str(pitch),
                                'Note_Duration_' + str(duration),
                                'Note_Velocity_' + str(velocity),
                            ))

                    final_sequence.extend(events_list)

        # --- end with BAR & EOS --- #
        final_sequence.append(self.create_cp_bar_event())
        final_sequence.append(self.create_cp_eos_event())
        return final_sequence
################################# for NB style encoding #################################
class Corpus2event_nb(Corpus2event_cp):
def __init__(self, num_features):
    '''
    For convenience in logging, the word "type" is used for the "metric"
    sub-token in the code, which makes it easy to compare with the other
    encoding schemes.
    '''
    # The parent constructor already stores num_features and builds the event
    # template through this class's _init_event_template override, so neither
    # step has to be repeated here.
    super().__init__(num_features)
def _init_event_template(self):
self.event_template = {}
if self.num_features == 8:
feature_names = ['type', 'beat', 'chord', 'tempo', 'instrument', 'pitch', 'duration', 'velocity']
elif self.num_features == 7:
feature_names = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
elif self.num_features == 5:
feature_names = ['type', 'beat', 'instrument', 'pitch', 'duration']
elif self.num_features == 4:
feature_names = ['type', 'beat', 'pitch', 'duration']
for feature_name in feature_names:
self.event_template[feature_name] = 0
def create_nb_sos_event(self):
    """Return a copy of the event template whose type is 'SOS'."""
    sos_event = dict(self.event_template)
    sos_event.update(type='SOS')
    return sos_event
def create_nb_eos_event(self):
    """Return a copy of the event template whose type is 'EOS'."""
    eos_event = dict(self.event_template)
    eos_event.update(type='EOS')
    return eos_event
def create_nb_event(self, bar_beat_type, pos, chord, tempo, instrument_name, pitch, duration, velocity):
    """Return one note token built from a copy of the event template.

    Type, beat, pitch and duration are always stored. The instrument is
    stored for 5 or 8 features; chord, tempo and velocity for 7 or 8.
    """
    event = dict(self.event_template)
    event.update(type=bar_beat_type, beat=pos, pitch=pitch, duration=duration)
    if self.num_features in {5, 8}:
        event['instrument'] = instrument_name
    if self.num_features in {7, 8}:
        event.update(chord=chord, tempo=tempo, velocity=velocity)
    return event
def create_nb_empty_bar_event(self):
    """Return a copy of the event template whose type is 'Empty_Bar'."""
    empty_bar_event = dict(self.event_template)
    empty_bar_event.update(type='Empty_Bar')
    return empty_bar_event
def get_bar_beat_idx(self, bar_flag, beat_flag, time_sig_name, time_sig_change_flag):
    '''
    Return the metric label of the current bar and beat.

    There are four labels: NNN, SNN, SSN and SSS. The three letters stand for
    the time signature, the bar and the beat, each being new (N) or the same
    as before (S). The NNN label is suffixed with the time signature name.
    The combination "new bar, same beat" has no label and yields None.
    '''
    if time_sig_change_flag:
        # new time signature
        return "NNN_" + time_sig_name
    if beat_flag:
        # same time signature & new beat, in a new bar or in the same bar
        return "SNN" if bar_flag else "SSN"
    if not bar_flag:
        # same time signature & same bar & same beat
        return "SSS"
    return None
def __call__(self, song_data, in_beat_resolution:int):
    """Encode one song as a flat list of note events framed by SOS and EOS.

    Args:
        song_data: Mapping read through these keys:
            'metadata' -> 'first_note', 'last_note', 'time_signature',
            'ticks_per_beat';
            'notes' -> per-instrument mapping from tick to a list of notes
            (each note exposes ``pitch``, ``quantized_duration``, ``velocity``);
            'chords' / 'tempos' -> mapping from tick to a list whose last
            item exposes ``text`` (three '_'-separated parts) / ``tempo``.
        in_beat_resolution: Number of grid positions per beat; the grid step
            is ``ticks_per_beat / in_beat_resolution`` ticks.

    Returns:
        The list of event dicts, or None when ``_process_time_signature``
        returns None.
    """
    # --- global tag --- #
    metadata = song_data['metadata']
    first_note_tick = metadata['first_note']            # First note timestamp in ticks
    global_end = metadata['last_note']                  # Last note timestamp in ticks
    time_signature_changes = metadata['time_signature'] # Time signature changes throughout the song
    ticks_per_beat = metadata['ticks_per_beat']         # Ticks per beat (resolution of the timing grid)
    # Kept as a Fraction so grid positions stay exact when the division is not integral.
    in_beat_tick_resol = Fraction(ticks_per_beat, in_beat_resolution)
    instrument_list = sorted(song_data['notes'].keys())

    # --- process time signature --- #
    time_signature_changes = self._process_time_signature(time_signature_changes, ticks_per_beat, first_note_tick, global_end)
    if time_signature_changes is None:
        return None  # Exit if no valid time signature changes found

    # --- create sequence --- #
    final_sequence = [self.create_nb_sos_event()]
    chord_text = None  # Most recent chord label; carried over until the next chord
    tempo_text = None  # Most recent tempo label; carried over until the next tempo
    uses_chord_tempo = self.num_features in {7, 8}
    last_sig_idx = len(time_signature_changes) - 1

    for idx, time_sig in enumerate(time_signature_changes):
        # Stays True until the first note event under this time signature is emitted.
        time_sig_change_flag = True
        numerator = time_sig.numerator
        denominator = time_sig.denominator
        time_sig_name = f'time_signature_{numerator}/{denominator}'
        bar_resol = int(ticks_per_beat * numerator * (4 / denominator))  # Ticks per bar
        bar_start_tick = time_sig.time

        # This time signature runs until the next change, or until the end of the song.
        if idx == last_sig_idx:
            next_change_point = global_end
        else:
            next_change_point = time_signature_changes[idx + 1].time

        for measure_step in frange(bar_start_tick, next_change_point, bar_resol):
            bar_flag = True
            note_flag = False

            for in_beat_off_idx, beat_step in enumerate(frange(measure_step, measure_step + bar_resol, in_beat_tick_resol)):
                beat_flag = True
                events_list = []
                pos_text = 'Beat_' + str(in_beat_off_idx)

                # --- chord & tempo processing --- #
                t_chords = song_data['chords'].get(beat_step)
                t_tempos = song_data['tempos'].get(beat_step)
                if uses_chord_tempo:
                    if t_chords is not None:
                        # Chord text has three parts; the third one is discarded.
                        root, quality, _ = t_chords[-1].text.split('_')
                        chord_text = 'Chord_' + root + '_' + quality
                    if t_tempos is not None:
                        tempo_text = 'Tempo_' + str(t_tempos[-1].tempo)

                # --- instrument & note processing --- #
                # instrument -> pitch -> [(quantized_duration, velocity), ...]
                instrument_note_dict = defaultdict(dict)
                for instrument_idx in instrument_list:
                    t_notes = song_data['notes'][instrument_idx].get(beat_step)
                    if t_notes is None:
                        continue
                    note_flag = True
                    pitch_map = instrument_note_dict[instrument_idx]
                    for note in t_notes:
                        pitch_map.setdefault(note.pitch, []).append((note.quantized_duration, note.velocity))

                if not instrument_note_dict:
                    continue

                # Check for half-step interval gaps and handle them across instruments
                pruned_instrument_note_dict = self._half_step_interval_gap_check_across_instruments(instrument_note_dict)

                # Fall back to placeholder labels when no chord / tempo has been seen yet.
                if uses_chord_tempo:
                    if chord_text is None:
                        chord_text = 'Chord_N_N'
                    if tempo_text is None:
                        tempo_text = 'Tempo_N_N'

                # add instrument and note
                for instrument_idx in pruned_instrument_note_dict:
                    instrument_name = 'Instrument_' + str(instrument_idx)
                    for pitch in pruned_instrument_note_dict[instrument_idx]:
                        # Same pitch on the same onset: keep the entry with the longest duration.
                        max_duration = max(pruned_instrument_note_dict[instrument_idx][pitch], key=lambda x: x[0])
                        note_pitch_text = 'Note_Pitch_' + str(pitch)
                        note_duration_text = 'Note_Duration_' + str(max_duration[0])
                        note_velocity_text = 'Note_Velocity_' + str(max_duration[1])
                        bar_beat_type = self.get_bar_beat_idx(bar_flag, beat_flag, time_sig_name, time_sig_change_flag)
                        events_list.append(self.create_nb_event(bar_beat_type, pos_text, chord_text, tempo_text, instrument_name, note_pitch_text, note_duration_text, note_velocity_text))
                        # Only the first event of a bar / beat / time signature is marked as "new".
                        bar_flag = False
                        beat_flag = False
                        time_sig_change_flag = False

                if events_list:
                    final_sequence.extend(events_list)

            # when there is no note in this bar
            if not note_flag:
                # avoid consecutive empty bars (more than 4 is not allowed)
                empty_bar_token = self.create_nb_empty_bar_event()
                four_empty_in_a_row = len(final_sequence) >= 4 and all(
                    event == empty_bar_token for event in final_sequence[-4:]
                )
                if not four_empty_in_a_row:
                    final_sequence.append(empty_bar_token)

    # --- end with EOS --- #
    final_sequence.append(self.create_nb_eos_event())
    return final_sequence