first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
from sf2utils.sf2parse import Sf2File
def print_sorted_presets(sf2_path):
presets_info = []
with open(sf2_path, 'rb') as f:
sf2 = Sf2File(f)
for preset in sf2.presets:
try:
# Try reading the attributes directly
name = getattr(preset, 'name', '???').strip('\x00')
bank = getattr(preset, 'bank', None)
program = getattr(preset, 'preset', None)
# If unavailable, try to read them from a sub-attribute
if bank is None or program is None:
for attr in dir(preset):
attr_value = getattr(preset, attr)
if hasattr(attr_value, 'bank') and hasattr(attr_value, 'preset'):
bank = attr_value.bank
program = attr_value.preset
name = getattr(attr_value, 'name', name).strip('\x00')
break
# Collect valid results
if bank is not None and program is not None:
presets_info.append((program, bank, name))
except Exception as e:
print(f"Error reading preset: {e}")
# Sort by program in ascending order (to sort by bank, then program, use sorted(..., key=lambda x: (x[1], x[0])))
presets_info.sort(key=lambda x: x[0])
# Print the results
print(f"{'Program':<8} {'Bank':<6} {'Preset Name'}")
print("-" * 40)
for program, bank, name in presets_info:
print(f"{program:<8} {bank:<6} {name}")
# DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# Replace with the path to your .sf2 file
sf2_path = "/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2"
print_sorted_presets(sf2_path)
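Not part of the commit: a small, hypothetical CLI wrapper sketch, in case the hard-coded sf2_path above is later replaced with a command-line argument (the argument name is an assumption).
import argparse

if __name__ == '__main__':
    # Hypothetical wrapper: pass the .sf2 path on the command line instead of hard-coding it.
    parser = argparse.ArgumentParser(description='List the presets of an SF2 sound font')
    parser.add_argument('sf2_path', help='path to the .sf2 file')
    args = parser.parse_args()
    print_sorted_presets(args.sf2_path)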

View File

@@ -0,0 +1,94 @@
import random
from typing import Union
import torch
class Augmentor:
def __init__(
self,
vocab,
aug_type:Union[str, None],
input_length:int
):
self.vocab = vocab
self.aug_type = aug_type
self.input_length = input_length
self.feature_list = vocab.feature_list
self.num_features = len(self.feature_list)
self.encoding_scheme = vocab.encoding_scheme
self.pitch_idx = self.feature_list.index('pitch')
if 'chord' in self.feature_list:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
# index 0 of the pitch vocab is the ignore token
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0
pitch_segment = segment[pitch_mask[:,self.pitch_idx], self.pitch_idx]
# check if tensor is empty
if pitch_segment.numel() == 0:
shift = 0
else:
lowest_pitch = max(12, torch.min(pitch_segment))
highest_pitch = min(119, torch.max(pitch_segment))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) > 11)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) < 120)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
else: # remi
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
segment_pitch = segment * segment_pitch_mask
segment_pitch = segment_pitch[segment_pitch != 0]
# check if tensor is empty
if segment_pitch.numel() == 0:
shift = 0
else:
lower_bound = torch.argwhere(mask_for_pitch == 1)[0].item()
upper_bound = torch.argwhere(mask_for_pitch == 1)[-1].item()
lowest_pitch = max(lower_bound, torch.min(segment_pitch))
highest_pitch = min(upper_bound, torch.max(segment_pitch))
lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) >= lower_bound)[0][-1].item()
upper_shift_bound = torch.where(highest_pitch + torch.arange(7) <= upper_bound)[0][-1].item()
shift = random.randint(-lower_shift_bound, upper_shift_bound)
return shift
# TODO: arrange hard coded part
def __call__(self, segment):
'''
segment holds slices of x and y.
For transformer_xl, x and y have shape [max_num_segments, input_length, num_features],
so they are reshaped to [max_num_segments*input_length, num_features] before augmentation.
'''
if self.aug_type == 'random':
shift = self._get_shift(segment)
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
# pitch augmentation
segment_pitch_mask = segment != 0
new_segment = segment.clone()
new_segment[segment_pitch_mask[:,self.pitch_idx], self.pitch_idx] += shift
if 'chord' in self.feature_list:
# chord augmentation
segment_chord_mask = (segment[:,self.chord_idx] != 0) & (segment[:,self.chord_idx] != 1)
new_segment[segment_chord_mask, self.chord_idx] = (((new_segment[segment_chord_mask, self.chord_idx]-2) % 12) + shift ) % 12 + ((new_segment[segment_chord_mask, self.chord_idx]-2) // 12) * 12 + 2
segment = new_segment
else: # remi
# choose a random integer between -5 and 6
# the augmented results for shifts of -6 and +6 are the same, so the range is -5 to 6
# pitch augmentation
mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
segment_pitch_mask = mask_for_pitch[segment]
new_segment = segment.clone()
new_segment_valid = (new_segment + shift) * segment_pitch_mask
new_segment = new_segment * (1 - segment_pitch_mask) + new_segment_valid
if 'chord' in self.feature_list:
# chord augmentation
mask_for_chord = self.vocab.total_mask['chord'].clone().to(segment.device)
chord_n_n_idx = torch.argwhere(mask_for_chord == 1)[-1].item()
mask_for_chord[chord_n_n_idx] = 0
start_idx_chord = self.vocab.remi_vocab_boundaries_by_key['chord'][0]
segment_chord_mask = mask_for_chord[segment]
new_segment_valid = ((((new_segment - start_idx_chord) % 12 + shift) % 12) + ((new_segment - start_idx_chord) // 12) * 12 + start_idx_chord) * segment_chord_mask
new_segment = new_segment * (1 - segment_chord_mask) + new_segment_valid
segment = new_segment
return segment
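A minimal sketch of the chord transposition arithmetic used in Augmentor.__call__ above, assuming chord tokens start at index 2 and are laid out as 12 roots per chord quality (that layout is inferred from the modular arithmetic, not stated elsewhere in the commit):
def transpose_chord_token(idx, shift, offset=2):
    # root within its quality block, and which quality block the token sits in
    root = (idx - offset) % 12
    quality_block = (idx - offset) // 12
    return ((root + shift) % 12) + quality_block * 12 + offset

# Shifting by +2 semitones keeps the token inside the same quality block:
assert transpose_chord_token(5, 2) == 7      # root 3 -> root 5, block 0
assert transpose_chord_token(25, 2) == 15    # root 11 wraps to root 1, block 1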

View File

@@ -0,0 +1,207 @@
import random
from collections import defaultdict
import torch
import numpy as np
def reverse_shift_and_pad(tune_in_idx, slice_boundary=4):
new_lst = [curr_elems[:slice_boundary] + next_elems[slice_boundary:] for curr_elems, next_elems in zip(tune_in_idx, tune_in_idx[1:])]
return new_lst
def reverse_shift_and_pad_for_tensor(tensor, first_pred_feature):
'''
tensor: [batch_size x seq_len x feature_size]
'''
if first_pred_feature == 'type':
return tensor
if tensor.shape[-1] == 8:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
elif tensor.shape[-1] == 7:
slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'pitch':4, 'duration':5, 'velocity':6}
elif tensor.shape[-1] == 5:
slice_boundary_dict = {'type':0, 'beat':1, 'instrument':2, 'pitch':3, 'duration':4}
elif tensor.shape[-1] == 4:
slice_boundary_dict = {'type':0, 'beat':1, 'pitch':2, 'duration':3}
slice_boundary = slice_boundary_dict[first_pred_feature]
new_tensor = torch.zeros_like(tensor)
new_tensor[..., :, :slice_boundary] = tensor[..., :, :slice_boundary]
new_tensor[..., :-1, slice_boundary:] = tensor[..., 1:, slice_boundary:]
return new_tensor
def shift_and_pad(tune_in_idx, first_pred_feature):
if first_pred_feature == 'type':
return tune_in_idx
if len(tune_in_idx[0]) == 8:
slice_boundary_dict = {'type':0, 'beat':-7, 'chord':-6, 'tempo':-5, 'instrument':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 7:
slice_boundary_dict = {'type':0, 'beat':-6, 'chord':-5, 'tempo':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
elif len(tune_in_idx[0]) == 5:
slice_boundary_dict = {'type':0, 'beat':-4, 'instrument':-3, 'pitch':-2, 'duration':-1}
elif len(tune_in_idx[0]) == 4:
slice_boundary_dict = {'type':0, 'beat':-3, 'pitch':-2, 'duration':-1}
slice_boundary = slice_boundary_dict[first_pred_feature]
# Prepend a row of zeros at the beginning; the sos and eos tokens are not shifted
padded_tune_in_idx = torch.cat([torch.zeros(1, len(tune_in_idx[0]), dtype=torch.long), tune_in_idx], dim=0)
new_tensor = torch.zeros_like(padded_tune_in_idx)
new_tensor[:, slice_boundary:] = padded_tune_in_idx[:, slice_boundary:]
new_tensor[:-1, :slice_boundary] = padded_tune_in_idx[1:, :slice_boundary]
return new_tensor
class VanillaTransformer_compiler():
def __init__(
self,
data_list,
augmentor,
eos_token,
input_length,
first_pred_feature,
encoding_scheme
):
self.data_list = data_list
self.augmentor = augmentor
self.eos_token = eos_token
self.input_length = input_length
self.first_pred_feature = first_pred_feature
self.encoding_scheme = encoding_scheme
def make_segments(self, data_type):
segments = []
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segments.append([segment, mask])
segment2tune_name.append(tune_name)
assert len(segment) == self.input_length+1
# Randomly choose the start point for the next segment, which is in the range of half of the current segment to the end of the current segment
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# if text-controlled, we only use the first segment
# add the last segment
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
segments.append([segment, mask])
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segments.append([segment, mask])
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
assert len(segment) == self.input_length+1
return segments, tune_name2segment, segment2tune_name
def make_segments_iters(self, data_type):
tune_name2segment = defaultdict(list)
segment2tune_name = []
num_segments = 0
# shuffle the data_list
if data_type == 'train':
random.shuffle(self.data_list)
print("length of data_list:", len(self.data_list))
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
eos_token = torch.LongTensor(self.eos_token)
# shift and pad
tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
if data_type == 'train':
if len(tune_in_idx) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx, padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
start_point = 0
while start_point + self.input_length+1 < len(tune_in_idx):
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment = tune_in_idx[start_point:start_point + self.input_length+1]
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1
start_point += random.randint((self.input_length+1)//2, self.input_length+1)
# break
if len(tune_in_idx[start_point:]) < self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
segment2tune_name.append(tune_name)
yield [segment, mask], tune_name2segment, segment2tune_name
else: # for validset
for i in range(0, len(tune_in_idx), self.input_length+1):
segment = tune_in_idx[i:i+self.input_length+1]
if len(segment) <= self.input_length+1:
if 'remi' in self.encoding_scheme:
padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
else:
padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
segment = torch.cat([segment, padding_seq], dim=0)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
else:
mask = torch.ones(self.input_length+1, dtype=torch.long)
segment2tune_name.append(tune_name)
num_segments += 1
tune_name2segment[tune_name].append(num_segments-1)
yield [segment, mask], tune_name2segment, segment2tune_name
assert len(segment) == self.input_length+1
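A toy check of the alignment produced by shift_and_pad, meant to be run in the same module as the functions above (the expected tensor is worked out by hand from the slicing logic):
import torch

toy = torch.LongTensor([[1, 2, 3, 4],
                        [5, 6, 7, 8]])    # columns: type, beat, pitch, duration
shifted = shift_and_pad(toy, 'pitch')     # slice_boundary = -2 for a 4-feature vocab
print(shifted)
# Each row now pairs the type/beat of note i with the pitch/duration of note i-1:
# tensor([[1, 2, 0, 0],
#         [5, 6, 3, 4],
#         [0, 0, 7, 8]])
# reverse_shift_and_pad_for_tensor(shifted, 'pitch') recovers the original rows
# (plus a trailing all-zero row).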

File diff suppressed because it is too large

View File

@@ -0,0 +1,404 @@
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
class MuteWarn:
def __enter__(self):
self._init_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._init_stdout
def save_score_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
with MuteWarn():
convert = converter.parse(midi_fn)
convert.write('musicxml.png', fp=file_name)
def save_pianoroll_image_from_midi(midi_fn, file_name):
assert isinstance(midi_fn, str)
midi_obj_muspy = muspy.read_midi(midi_fn)
midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
plt.gcf().set_size_inches(20, 10)
plt.savefig(file_name)
plt.close()
def save_wav_from_midi(midi_fn, file_name, qpm=80):
assert isinstance(midi_fn, str)
with MuteWarn():
music = muspy.read_midi(midi_fn)
music.tempos[0].qpm = qpm
music.write_audio(file_name, rate=44100, gain=3)
def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
assert isinstance(midi_fn, str)
fs = FluidSynth(gain=gain)
fs.midi_to_audio(midi_fn, file_name)
class MidiDecoder4REMI:
def __init__(
self,
vocab,
in_beat_resolution,
dataset_name
):
self.vocab = vocab
self.in_beat_resolution = in_beat_resolution
self.dataset_name = dataset_name
if dataset_name == 'SymphonyMIDI':
self.gain = 0.7
elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
self.gain = 1.1
elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
self.gain = 2.5
else:
self.gain = 1.5
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor of REMI token indices, shape [seq_len] or [1, seq_len]
'''
idx2event = self.vocab.idx2event
if generated_output.dim() == 2:
generated_output = generated_output.squeeze(0)
events = [idx2event[token.item()] for token in generated_output]
midi_obj = miditoolkit.midi.parser.MidiFile()
if 'tempo' not in idx2event.keys():
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
instr_notes_dict = defaultdict(list)
for i in range(len(events)-2):
cur_event = events[i]
# print(cur_event)
name = cur_event.split('_')[0]
attr = cur_event.split('_')
if name == 'Bar':
bar_pos += cur_bar_resol
if 'time' in cur_event:
cur_num, cur_denom = attr[-1].split('/')
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif name == 'Beat':
beat_pos = int(attr[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks
elif name == 'Chord':
chord_text = attr[1] + '_' + attr[2]
midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
elif name == 'Tempo':
midi_obj.tempo_changes.append(
TempoChange(tempo=int(attr[1]), time=cur_pos))
elif name == 'Instrument':
cur_instr = int(attr[1])
else:
if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1] and \
'Note_Velocity' in events[i+2]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = int(events[i+2].split('_')[-1])
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
if 'Note_Pitch' in events[i] and \
'Note_Duration' in events[i+1]:
pitch = int(events[i].split('_')[-1])
duration = int(events[i+1].split('_')[-1])
duration = duration * default_in_beat_ticks
end = cur_pos + duration
velocity = 90
instr_notes_dict[cur_instr].append(
Note(
pitch=pitch,
start=cur_pos,
end=end,
velocity=velocity))
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
# make subdir
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if output_path contains 'prompt', render the audio into the prompt_music directory
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4CP(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, beat, chord, tempo, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if 'time_signature' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['type']] == 'Metrical':
if 'time_signature' in token_with_7infos[feature2idx['beat']]:
cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
elif token_with_7infos[feature2idx['beat']] == 'Bar':
bar_pos += cur_bar_resol
elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# chord and tempo
midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
elif token_with_7infos[feature2idx['type']] == 'Note':
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 80
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
if not os.path.exists(output_music_dir):
os.makedirs(output_music_dir)
midi_obj.dump(output_path)
save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, os.path.join(output_music_dir, output_path.split('/')[-1].replace('.mid', '.wav')), gain=self.gain)
return midi_obj
class MidiDecoder4NB(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
if len(feature2idx) == 7 or len(feature2idx) == 8:
# chord
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
midi_obj.markers.append(
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
# tempo
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
midi_obj.tempo_changes.append(
TempoChange(tempo=tempo, time=cur_pos))
return midi_obj
elif len(feature2idx) == 4 or len(feature2idx) == 5:
return midi_obj
def __call__(self, generated_output, output_path=None):
'''
generated_output: tensor, batch x seq_len x num_types
num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
'''
idx2event = self.vocab.idx2event
feature_keys = self.vocab.feature_list
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
midi_obj = miditoolkit.midi.parser.MidiFile()
if len(feature2idx) == 4 or len(feature2idx) == 5:
default_tempo = 95
midi_obj.tempo_changes.append(
TempoChange(tempo=default_tempo, time=0))
default_ticks_per_beat = 480
default_in_beat_ticks = 480 // self.in_beat_resolution
cur_pos = 0
bar_pos = 0
cur_bar_resol = 0
beat_pos = 0
instr_notes_dict = defaultdict(list)
generated_output = generated_output.squeeze(0)
for i in range(len(generated_output)):
token_with_7infos = []
for kidx, key in enumerate(feature_keys):
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
# type token
if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
bar_pos += cur_bar_resol
elif 'NNN' in token_with_7infos[feature2idx['type']]:
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
bar_pos += cur_bar_resol
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
cur_bar_resol = new_bar_resol
midi_obj.time_signature_changes.append(
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
# instrument token
if len(feature2idx) == 8 or len(feature2idx) == 5:
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
else:
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
if 'Beat' in str(token_with_7infos[feature2idx['beat']]): # when beat is not CONTI beat is updated
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
# update chord and tempo
midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
# note
try:
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
duration = token_with_7infos[feature2idx['duration']].split('_')[-1] # duration between 1~192
duration = int(duration) * default_in_beat_ticks
if len(feature2idx) == 7 or len(feature2idx) == 8:
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
else:
velocity = 90
end = cur_pos + duration
instr_notes_dict[cur_instr].append(
Note(
pitch=int(pitch),
start=cur_pos,
end=end,
velocity=int(velocity))
)
except:
continue
else: # when new bar started without beat
continue
# save midi
for instr, notes in instr_notes_dict.items():
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
if instr == 114: is_drum = True
else: is_drum = False
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
instr_track.notes = notes
midi_obj.instruments.append(instr_track)
if isinstance(output_path, str) or isinstance(output_path, Path):
output_path = str(output_path)
music_path = os.path.join(os.path.dirname(output_path), 'music')
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
if not os.path.exists(music_path):
os.makedirs(music_path)
if not os.path.exists(prompt_music_path):
os.makedirs(prompt_music_path)
# if output_path contains 'prompt', render the audio into the prompt_music directory
if 'prompt' in output_path:
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
else:
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
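A minimal rendering sketch for the standalone helpers above; the file names are placeholders and both the fluidsynth binary and the configured sound font must be installed:
if __name__ == '__main__':
    midi_fn = 'example.mid'               # placeholder path
    save_pianoroll_image_from_midi(midi_fn, 'example_pianoroll.png')
    save_wav_from_midi_fluidsynth(midi_fn, 'example.wav', gain=1.5)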

View File

@@ -0,0 +1,208 @@
import torch
import numpy as np
from collections import Counter
# TODO: refactor hard coded values
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
type_beat_list = []
for token in generated_output:
type_beat_list.append((token[type_idx].item(), token[beat_idx].item())) # type, beat
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for type_beat in type_beat_list:
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
if type_beat[1] <= last_note:
num_backwards_errors += 1
beat_backwards_error_list.append([last_note, type_beat])
else:
last_note = type_beat[1] # update last note
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(type_beat)
last_note = 1
unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
backwards_error_rate = num_backwards_errors / len(type_beat_list)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
generated_output = generated_output.squeeze(0)
type_idx = feature_list.index('type')
beat_idx = feature_list.index('beat')
pitch_idx = feature_list.index('pitch')
duration_idx = feature_list.index('duration')
last_note = 1
beat_type_unmatched_error_list = []
num_unmatched_errors = 0
beat_backwards_error_list = []
num_backwards_errors = 0
for token in generated_output:
if token[type_idx].item() == 2: # Metrical
if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
if token[beat_idx].item() == 1: # new bar
last_note = 1 # last note will be updated in the next token
elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
num_backwards_errors += 1
last_note = token[beat_idx].item() # update last note
beat_backwards_error_list.append([last_note, token])
else:
last_note = token[beat_idx].item() # update last note
if token[type_idx].item() == 3: # Note
if token[beat_idx].item() != 0:
num_unmatched_errors += 1
beat_type_unmatched_error_list.append(token)
unmatched_error_rate = num_unmatched_errors / len(generated_output)
backwards_error_rate = num_backwards_errors / len(generated_output)
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
return type_beat_errors_dict
def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
generated_output = generated_output.squeeze(0)
# to check duration errors
beat_mask = vocab.total_mask['beat'].to(generated_output.device)
beat_mask_for_target = beat_mask[generated_output]
beat_target = generated_output * beat_mask_for_target
bar_mask = vocab.total_mask['type'].to(generated_output.device)
bar_mask_for_target = bar_mask[generated_output]
bar_target = (generated_output+1) * bar_mask_for_target # the bar token is index 0 in the remi vocab, so add 1 to keep it through the nonzero filter below
target = beat_target + bar_target
target = target[target!=0]
# collect beats in between bars(idx=1)
num_backwards_errors = 0
collected_beats = []
total_beats = 0
for token in target:
if token == 1 or 3 <= token <= 26: # Bar_None, or Bar_time_signature
collected_beats_tensor = torch.tensor(collected_beats)
diff = torch.diff(collected_beats_tensor)
num_error_beats = torch.where(diff<=0)[0].shape[0]
num_backwards_errors += num_error_beats
collected_beats = []
else:
collected_beats.append(token.item())
total_beats += 1
if total_beats != 0:
backwards_error_rate = num_backwards_errors / total_beats
else:
backwards_error_rate = 0
# print(f"error rate in beat backwards: {backwards_error_rate}")
return {'beat_backwards_error': backwards_error_rate}
def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_pred_beat_idx = pred_beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_pred_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() >= 1: # beat
last_note = input_beat_idx.item()
if type_beat[0] == 4: # same bar, new beat
if type_beat[1] == 0 or type_beat[1] == 1:
num_unmatched_errors += 1
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] >= 5: # new bar, new beat
if type_beat[1] == 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
bool_mask = mask.bool().flatten() # (b*t)
beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
valid_beat_idx = beat_idx[bool_mask] # valid beat_idx
answer_type = answer_type.flatten() # (b*t)
valid_type_input = answer_type[bool_mask] # valid answer_type
type_beat_list = []
for i in range(len(valid_beat_idx)):
type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item())) # type, beat
input_beat = input_beat.flatten()
valid_input_beat = input_beat[bool_mask]
last_note = 1
num_unmatched_errors = 0
num_backwards_errors = 0
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
# update last note
if input_beat_idx.item() == 1: # bar
last_note = 1
elif input_beat_idx.item() >= 2: # new beat
last_note = input_beat_idx.item()
# check errors
if type_beat[0] == 2: # Metrical
if type_beat[1] == 0: # ignore
num_unmatched_errors += 1
elif type_beat[1] >= 2: # new beat
if type_beat[1] <= last_note:
num_backwards_errors += 1
elif type_beat[0] == 3: # Note
if type_beat[1] != 0:
num_unmatched_errors += 1
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
orign_beat_prob = prob_dict['beat'] # b x t x vocab_size
arranged_beat_prob = arranged_prob_dict['beat'] # b x t x vocab_size
# calculate similarity between original beat prob and arranged beat prob
origin_beat_token = torch.argmax(orign_beat_prob, dim=-1) * mask # b x t
arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask # b x t
num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask==0)
num_beat = torch.sum(mask==1)
beat_sim = (num_same_beat / num_beat).item() # scalar
# apply mask, shape of mask: b x t
orign_beat_prob = orign_beat_prob * mask.unsqueeze(-1) # b x t x vocab_size
arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)
# calculate cosine similarity between original beat prob and arranged beat prob
orign_beat_prob = orign_beat_prob.flatten(0,1) # (b*t) x vocab_size
arranged_beat_prob = arranged_beat_prob.flatten(0,1) # (b*t) x vocab_size
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
beat_cos_sim = cos(orign_beat_prob, arranged_beat_prob) # (b*t)
# exclude invalid tokens, zero padding tokens
beat_cos_sim = beat_cos_sim[mask.flatten().bool()] # num_valid_tokens
beat_cos_sim = torch.mean(beat_cos_sim).item() # scalar
return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}
def get_gini_coefficient(generated_output):
if len(generated_output.shape) == 3:
generated_output = generated_output.squeeze(0).tolist()
gen_list = [tuple(x) for x in generated_output]
else:
gen_list = generated_output.squeeze(0).tolist()
counts = Counter(gen_list).values()
sorted_counts = sorted(counts)
n = len(sorted_counts)
cumulative_counts = np.cumsum(sorted_counts)
cumulative_proportion = cumulative_counts / cumulative_counts[-1]
lorenz_area = sum(cumulative_proportion[:-1]) / n # Exclude the last element
equality_area = 0.5 # The area under line of perfect equality
gini = (equality_area - lorenz_area) / equality_area
return gini
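A toy sanity check for get_gini_coefficient, meant to be run in the same module as the function above:
toy_output = torch.tensor([[[1, 2], [1, 2], [3, 4]]])   # batch x seq_len x num_features
print(get_gini_coefficient(toy_output))
# token counts are [2, 1] -> Lorenz area 1/6 -> Gini = (0.5 - 1/6) / 0.5 ≈ 0.667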

View File

@@ -0,0 +1,78 @@
import argparse
import os
import subprocess
from pydub import AudioSegment
'''
This file is a modified version of midi2audio.py from https://github.com/bzamecnik/midi2audio
Author: Bohumír Zámečník (@bzamecnik)
License: MIT, see the LICENSE file
'''
__all__ = ['FluidSynth']
DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# DEFAULT_SAMPLE_RATE = 16000
# DEFAULT_GAIN = 0.20
class FluidSynth():
def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
self.sample_rate = sample_rate
self.sound_font = os.path.expanduser(sound_font)
self.gain = gain
def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
if verbose:
stdout = None
else:
stdout = subprocess.DEVNULL
# Convert MIDI to WAV
subprocess.call(
['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
stdout=stdout
)
# Convert WAV to MP3
# mp3_path = audio_file.replace('.wav', '.mp3')
# AudioSegment.from_wav(audio_file).export(mp3_path, format="mp3")
# # Delete the temporary WAV file
# os.remove(audio_file)
def play_midi(self, midi_file):
subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])
def parse_args(allow_synth=True):
parser = argparse.ArgumentParser(description='Convert MIDI to audio via FluidSynth')
parser.add_argument('midi_file', metavar='MIDI', type=str)
if allow_synth:
parser.add_argument('audio_file', metavar='AUDIO', type=str, nargs='?')
parser.add_argument('-s', '--sound-font', type=str,
default=DEFAULT_SOUND_FONT,
help='path to a SF2 sound font (default: %s)' % DEFAULT_SOUND_FONT)
parser.add_argument('-r', '--sample-rate', type=int, nargs='?',
default=DEFAULT_SAMPLE_RATE,
help='sample rate in Hz (default: %s)' % DEFAULT_SAMPLE_RATE)
return parser.parse_args()
def main(allow_synth=True):
args = parse_args(allow_synth)
fs = FluidSynth(args.sound_font, args.sample_rate)
if allow_synth and args.audio_file:
fs.midi_to_audio(args.midi_file, args.audio_file)
else:
fs.play_midi(args.midi_file)
def main_play():
"""
A method for the `midiplay` entry point. It omits the audio file from args.
"""
main(allow_synth=False)
if __name__ == '__main__':
main()
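A short usage sketch of the class above, mirroring the relative import used in the decoder module of this commit; the paths are placeholders and the fluidsynth binary plus a readable .sf2 sound font are required:
from .midi2audio import FluidSynth

fs = FluidSynth(sound_font='~/.fluidsynth/default_sound_font.sf2', gain=0.5)
fs.midi_to_audio('input.mid', 'output.wav', verbose=False)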