first commit
0
Amadeus/symbolic_encoding/__init__.py
Normal file
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/augmentor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/augmentor.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/data_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/data_utils.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-310.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-311.pyc
Normal file
Binary file not shown.
BIN
Amadeus/symbolic_encoding/__pycache__/midi2audio.cpython-312.pyc
Normal file
Binary file not shown.
46
Amadeus/symbolic_encoding/anylazesf.py
Normal file
@@ -0,0 +1,46 @@
from sf2utils.sf2parse import Sf2File


def print_sorted_presets(sf2_path):
    presets_info = []

    with open(sf2_path, 'rb') as f:
        sf2 = Sf2File(f)

    for preset in sf2.presets:
        try:
            # try reading the fields directly
            name = getattr(preset, 'name', '???').strip('\x00')
            bank = getattr(preset, 'bank', None)
            program = getattr(preset, 'preset', None)

            # if that fails, try to pull them from a sub-attribute
            if bank is None or program is None:
                for attr in dir(preset):
                    attr_value = getattr(preset, attr)
                    if hasattr(attr_value, 'bank') and hasattr(attr_value, 'preset'):
                        bank = attr_value.bank
                        program = attr_value.preset
                        name = getattr(attr_value, 'name', name).strip('\x00')
                        break

            # collect valid results
            if bank is not None and program is not None:
                presets_info.append((program, bank, name))
        except Exception as e:
            print(f"Error reading preset: {e}")

    # sort by program in ascending order (to sort by bank then program, use sorted(..., key=lambda x: (x[1], x[0])))
    presets_info.sort(key=lambda x: x[0])

    # print the results
    print(f"{'Program':<8} {'Bank':<6} {'Preset Name'}")
    print("-" * 40)
    for program, bank, name in presets_info:
        print(f"{program:<8} {bank:<6} {name}")

# DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'

# replace with the path to your .sf2 file
sf2_path = "/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2"
print_sorted_presets(sf2_path)
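For reference, the script prints a fixed-width table like the one sketched below; the preset names are illustrative placeholders, since the actual rows depend entirely on the SoundFont being inspected:

Program  Bank   Preset Name
----------------------------------------
0        0      Grand Piano
25       0      Steel Guitar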
94
Amadeus/symbolic_encoding/augmentor.py
Normal file
@@ -0,0 +1,94 @@
import random
from typing import Union

import torch

class Augmentor:
    def __init__(
        self,
        vocab,
        aug_type:Union[str, None],
        input_length:int
    ):
        self.vocab = vocab
        self.aug_type = aug_type
        self.input_length = input_length
        self.feature_list = vocab.feature_list
        self.num_features = len(self.feature_list)
        self.encoding_scheme = vocab.encoding_scheme

        self.pitch_idx = self.feature_list.index('pitch')
        if 'chord' in self.feature_list:
            self.chord_idx = self.feature_list.index('chord')

    def _get_shift(self, segment):
        # the pitch vocab has an ignore token at index 0
        if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
            pitch_mask = segment != 0
            pitch_segment = segment[pitch_mask[:,self.pitch_idx], self.pitch_idx]
            # check if the tensor is empty
            if pitch_segment.numel() == 0:
                shift = 0
            else:
                lowest_pitch = max(12, torch.min(pitch_segment))
                highest_pitch = min(119, torch.max(pitch_segment))
                lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) > 11)[0][-1].item()
                upper_shift_bound = torch.where(highest_pitch + torch.arange(7) < 120)[0][-1].item()
                shift = random.randint(-lower_shift_bound, upper_shift_bound)
        else: # remi
            mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
            segment_pitch_mask = mask_for_pitch[segment]
            segment_pitch = segment * segment_pitch_mask
            segment_pitch = segment_pitch[segment_pitch != 0]
            # check if the tensor is empty
            if segment_pitch.numel() == 0:
                shift = 0
            else:
                lower_bound = torch.argwhere(mask_for_pitch == 1)[0].item()
                upper_bound = torch.argwhere(mask_for_pitch == 1)[-1].item()
                lowest_pitch = max(lower_bound, torch.min(segment_pitch))
                highest_pitch = min(upper_bound, torch.max(segment_pitch))
                lower_shift_bound = torch.where(lowest_pitch - torch.arange(6) >= lower_bound)[0][-1].item()
                upper_shift_bound = torch.where(highest_pitch + torch.arange(7) <= upper_bound)[0][-1].item()
                shift = random.randint(-lower_shift_bound, upper_shift_bound)
        return shift

    # TODO: arrange hard-coded part
    def __call__(self, segment):
        '''
        input_tensor is segments of x, y
        for transformer_xl, the shape of x, y is [max_num_segments, input_length, num_features]
        so we need to change the shape of x, y to [max_num_segments*input_length, num_features]
        '''
        if self.aug_type == 'random':
            shift = self._get_shift(segment)
            if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
                # pitch augmentation
                segment_pitch_mask = segment != 0
                new_segment = segment.clone()
                new_segment[segment_pitch_mask[:,self.pitch_idx], self.pitch_idx] += shift
                if 'chord' in self.feature_list:
                    # chord augmentation
                    segment_chord_mask = (segment[:,self.chord_idx] != 0) & (segment[:,self.chord_idx] != 1)
                    new_segment[segment_chord_mask, self.chord_idx] = (((new_segment[segment_chord_mask, self.chord_idx]-2) % 12) + shift) % 12 + ((new_segment[segment_chord_mask, self.chord_idx]-2) // 12) * 12 + 2
                segment = new_segment
            else: # remi
                # choose a random integer between -5 and 6;
                # shifts of -6 and 6 produce the same augmented result, so we use -5 to 6
                # pitch augmentation
                mask_for_pitch = self.vocab.total_mask['pitch'].to(segment.device)
                segment_pitch_mask = mask_for_pitch[segment]
                new_segment = segment.clone()
                new_segment_valid = (new_segment + shift) * segment_pitch_mask
                new_segment = new_segment * (1 - segment_pitch_mask) + new_segment_valid
                if 'chord' in self.feature_list:
                    # chord augmentation
                    mask_for_chord = self.vocab.total_mask['chord'].clone().to(segment.device)
                    chord_n_n_idx = torch.argwhere(mask_for_chord == 1)[-1].item()
                    mask_for_chord[chord_n_n_idx] = 0
                    start_idx_chord = self.vocab.remi_vocab_boundaries_by_key['chord'][0]
                    segment_chord_mask = mask_for_chord[segment]
                    new_segment_valid = ((((new_segment - start_idx_chord) % 12 + shift) % 12) + ((new_segment - start_idx_chord) // 12) * 12 + start_idx_chord) * segment_chord_mask
                    new_segment = new_segment * (1 - segment_chord_mask) + new_segment_valid
                segment = new_segment
        return segment
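A minimal usage sketch of the augmentor, assuming a stub vocab that exposes only the attributes Augmentor reads on the 'cp' path; StubVocab, the feature layout, and the token values are illustrative, not the project's real vocabulary:

# sketch: transpose a compound-word segment with Augmentor
import torch

class StubVocab:
    feature_list = ['type', 'beat', 'chord', 'tempo', 'pitch', 'duration', 'velocity']
    encoding_scheme = 'cp'

aug = Augmentor(StubVocab(), aug_type='random', input_length=1024)
segment = torch.zeros(16, 7, dtype=torch.long)    # [seq_len, num_features]
segment[:, 4] = torch.randint(40, 80, (16,))      # fake pitch indices, well inside the 12..119 range
augmented = aug(segment)                          # pitch column transposed by one random shift
print((augmented[:, 4] - segment[:, 4]).unique()) # a single shift value, here within [-5, 6]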
207
Amadeus/symbolic_encoding/compile_utils.py
Normal file
@@ -0,0 +1,207 @@
import random
from collections import defaultdict

import torch
import numpy as np

def reverse_shift_and_pad(tune_in_idx, slice_boundary=4):
    new_lst = [curr_elems[:slice_boundary] + next_elems[slice_boundary:] for curr_elems, next_elems in zip(tune_in_idx, tune_in_idx[1:])]
    return new_lst

def reverse_shift_and_pad_for_tensor(tensor, first_pred_feature):
    '''
    tensor: [batch_size x seq_len x feature_size]
    '''
    if first_pred_feature == 'type':
        return tensor
    if tensor.shape[-1] == 8:
        slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'instrument':4, 'pitch':5, 'duration':6, 'velocity':7}
    elif tensor.shape[-1] == 7:
        slice_boundary_dict = {'type':0, 'beat':1, 'chord':2, 'tempo':3, 'pitch':4, 'duration':5, 'velocity':6}
    elif tensor.shape[-1] == 5:
        slice_boundary_dict = {'type':0, 'beat':1, 'instrument':2, 'pitch':3, 'duration':4}
    elif tensor.shape[-1] == 4:
        slice_boundary_dict = {'type':0, 'beat':1, 'pitch':2, 'duration':3}
    slice_boundary = slice_boundary_dict[first_pred_feature]
    new_tensor = torch.zeros_like(tensor)
    new_tensor[..., :, :slice_boundary] = tensor[..., :, :slice_boundary]
    new_tensor[..., :-1, slice_boundary:] = tensor[..., 1:, slice_boundary:]
    return new_tensor

def shift_and_pad(tune_in_idx, first_pred_feature):
    if first_pred_feature == 'type':
        return tune_in_idx
    if len(tune_in_idx[0]) == 8:
        slice_boundary_dict = {'type':0, 'beat':-7, 'chord':-6, 'tempo':-5, 'instrument':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
    elif len(tune_in_idx[0]) == 7:
        slice_boundary_dict = {'type':0, 'beat':-6, 'chord':-5, 'tempo':-4, 'pitch':-3, 'duration':-2, 'velocity':-1}
    elif len(tune_in_idx[0]) == 5:
        slice_boundary_dict = {'type':0, 'beat':-4, 'instrument':-3, 'pitch':-2, 'duration':-1}
    elif len(tune_in_idx[0]) == 4:
        slice_boundary_dict = {'type':0, 'beat':-3, 'pitch':-2, 'duration':-1}
    slice_boundary = slice_boundary_dict[first_pred_feature]
    # add a row of zeros at the beginning; sos and eos tokens are not shifted
    padded_tune_in_idx = torch.cat([torch.zeros(1, len(tune_in_idx[0]), dtype=torch.long), tune_in_idx], dim=0)
    new_tensor = torch.zeros_like(padded_tune_in_idx)
    new_tensor[:, slice_boundary:] = padded_tune_in_idx[:, slice_boundary:]
    new_tensor[:-1, :slice_boundary] = padded_tune_in_idx[1:, :slice_boundary]
    return new_tensor

class VanillaTransformer_compiler():
    def __init__(
        self,
        data_list,
        augmentor,
        eos_token,
        input_length,
        first_pred_feature,
        encoding_scheme
    ):
        self.data_list = data_list
        self.augmentor = augmentor
        self.eos_token = eos_token
        self.input_length = input_length
        self.first_pred_feature = first_pred_feature
        self.encoding_scheme = encoding_scheme

    def make_segments(self, data_type):
        segments = []
        tune_name2segment = defaultdict(list)
        segment2tune_name = []
        num_segments = 0
        for i in range(len(self.data_list)):
            tune_in_idx, tune_name = self.data_list[i]
            tune_in_idx = torch.LongTensor(tune_in_idx)
            # the eos token layout is the same for every encoding scheme here
            eos_token = torch.LongTensor(self.eos_token)
            # shift and pad
            tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
            if data_type == 'train':
                if len(tune_in_idx) <= self.input_length+1:
                    if 'remi' in self.encoding_scheme:
                        padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
                    else:
                        padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
                    mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                    segment = torch.cat([tune_in_idx, padding_seq], dim=0)
                    segments.append([segment, mask])
                    segment2tune_name.append(tune_name)
                else:
                    start_point = 0
                    while start_point + self.input_length+1 < len(tune_in_idx):
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment = tune_in_idx[start_point:start_point + self.input_length+1]
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)
                        assert len(segment) == self.input_length+1
                        # randomly choose the start point of the next segment, in the range from half of the current segment to its end
                        start_point += random.randint((self.input_length+1)//2, self.input_length+1)
                    # if text-controlled, we only use the first segment
                    # add the last segment
                    if len(tune_in_idx[start_point:]) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
                        mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)

            else: # for validset
                for i in range(0, len(tune_in_idx), self.input_length+1):
                    segment = tune_in_idx[i:i+self.input_length+1]
                    if len(segment) <= self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
                        mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([segment, padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        segments.append([segment, mask])
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                    else:
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segments.append([segment, mask])
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        assert len(segment) == self.input_length+1

        return segments, tune_name2segment, segment2tune_name

    def make_segments_iters(self, data_type):
        tune_name2segment = defaultdict(list)
        segment2tune_name = []
        num_segments = 0
        # shuffle the data_list
        if data_type == 'train':
            random.shuffle(self.data_list)
        print("length of data_list:", len(self.data_list))
        for i in range(len(self.data_list)):
            tune_in_idx, tune_name = self.data_list[i]
            tune_in_idx = torch.LongTensor(tune_in_idx)
            # the eos token layout is the same for every encoding scheme here
            eos_token = torch.LongTensor(self.eos_token)
            # shift and pad
            tune_in_idx = shift_and_pad(tune_in_idx, self.first_pred_feature)
            if data_type == 'train':
                if len(tune_in_idx) <= self.input_length+1:
                    if 'remi' in self.encoding_scheme:
                        padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx))
                    else:
                        padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx), 1)
                    mask = torch.cat([torch.ones(len(tune_in_idx), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                    segment = torch.cat([tune_in_idx, padding_seq], dim=0)
                    segment2tune_name.append(tune_name)
                    yield [segment, mask], tune_name2segment, segment2tune_name
                else:
                    start_point = 0
                    while start_point + self.input_length+1 < len(tune_in_idx):
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment = tune_in_idx[start_point:start_point + self.input_length+1]
                        segment2tune_name.append(tune_name)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                        assert len(segment) == self.input_length+1
                        start_point += random.randint((self.input_length+1)//2, self.input_length+1)
                        # break
                    if len(tune_in_idx[start_point:]) < self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(tune_in_idx[start_point:]))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(tune_in_idx[start_point:]), 1)
                        mask = torch.cat([torch.ones(len(tune_in_idx[start_point:]), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([tune_in_idx[start_point:], padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        yield [segment, mask], tune_name2segment, segment2tune_name
            else: # for validset
                for i in range(0, len(tune_in_idx), self.input_length+1):
                    segment = tune_in_idx[i:i+self.input_length+1]
                    if len(segment) <= self.input_length+1:
                        if 'remi' in self.encoding_scheme:
                            padding_seq = eos_token[0].repeat(self.input_length+1-len(segment))
                        else:
                            padding_seq = eos_token.repeat(self.input_length+1-len(segment), 1)
                        mask = torch.cat([torch.ones(len(segment), dtype=torch.long), torch.zeros(len(padding_seq), dtype=torch.long)], dim=0)
                        segment = torch.cat([segment, padding_seq], dim=0)
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                    else:
                        mask = torch.ones(self.input_length+1, dtype=torch.long)
                        segment2tune_name.append(tune_name)
                        num_segments += 1
                        tune_name2segment[tune_name].append(num_segments-1)
                        yield [segment, mask], tune_name2segment, segment2tune_name
                        assert len(segment) == self.input_length+1
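A small worked example of what shift_and_pad does, using the 4-feature layout ('type', 'beat', 'pitch', 'duration') and first_pred_feature='pitch'; the token values are made up to show the mechanics:

# sketch: shift_and_pad aligns each step's type/beat with the previous step's pitch/duration
import torch

tune = torch.tensor([
    [1, 1, 60, 4],
    [1, 2, 62, 4],
    [1, 3, 64, 4],
])
shifted = shift_and_pad(tune, first_pred_feature='pitch')
print(shifted)
# tensor([[ 1,  1,  0,  0],    a zero row is prepended, then the type/beat
#         [ 1,  2, 60,  4],    sub-tokens move one step earlier, so each row
#         [ 1,  3, 62,  4],    pairs the current step's type/beat with the
#         [ 0,  0, 64,  4]])   previous step's pitch/duration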
1610
Amadeus/symbolic_encoding/data_utils.py
Normal file
File diff suppressed because it is too large
404
Amadeus/symbolic_encoding/decoding_utils.py
Normal file
@@ -0,0 +1,404 @@
import os, sys
from pathlib import Path

import matplotlib.pyplot as plt
from collections import defaultdict

from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature

from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP

class MuteWarn:
    def __enter__(self):
        self._init_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._init_stdout

def save_score_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        convert = converter.parse(midi_fn)
        convert.write('musicxml.png', fp=file_name)

def save_pianoroll_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    midi_obj_muspy = muspy.read_midi(midi_fn)
    midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
    plt.gcf().set_size_inches(20, 10)
    plt.savefig(file_name)
    plt.close()

def save_wav_from_midi(midi_fn, file_name, qpm=80):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        music = muspy.read_midi(midi_fn)
        music.tempos[0].qpm = qpm
        music.write_audio(file_name, rate=44100, gain=3)

def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
    assert isinstance(midi_fn, str)
    fs = FluidSynth(gain=gain)
    fs.midi_to_audio(midi_fn, file_name)

class MidiDecoder4REMI:
    def __init__(
        self,
        vocab,
        in_beat_resolution,
        dataset_name
    ):
        self.vocab = vocab
        self.in_beat_resolution = in_beat_resolution
        self.dataset_name = dataset_name
        if dataset_name == 'SymphonyMIDI':
            self.gain = 0.7
        elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
            self.gain = 1.1
        elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
            self.gain = 2.5
        else:
            self.gain = 1.5

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor of generated REMI token indices; a leading batch dimension of size 1 is squeezed out
        '''
        idx2event = self.vocab.idx2event
        if generated_output.dim() == 2:
            generated_output = generated_output.squeeze(0)
        events = [idx2event[token.item()] for token in generated_output]

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if 'tempo' not in idx2event.keys():
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
        instr_notes_dict = defaultdict(list)
        for i in range(len(events)-2):
            cur_event = events[i]
            # print(cur_event)
            name = cur_event.split('_')[0]
            attr = cur_event.split('_')
            if name == 'Bar':
                bar_pos += cur_bar_resol
                if 'time' in cur_event:
                    cur_num, cur_denom = attr[-1].split('/')
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif name == 'Beat':
                beat_pos = int(attr[1])
                cur_pos = bar_pos + beat_pos * default_in_beat_ticks
            elif name == 'Chord':
                chord_text = attr[1] + '_' + attr[2]
                midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
            elif name == 'Tempo':
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=int(attr[1]), time=cur_pos))
            elif name == 'Instrument':
                cur_instr = int(attr[1])
            else:
                if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
                    if 'Note_Pitch' in events[i] and \
                        'Note_Duration' in events[i+1] and \
                        'Note_Velocity' in events[i+2]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = int(events[i+2].split('_')[-1])
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))
                elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
                    if 'Note_Pitch' in events[i] and \
                        'Note_Duration' in events[i+1]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = 90
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114  # program 114 marks the drum track here
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            # make subdirectories for the rendered audio
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
                os.makedirs(prompt_music_path)
            # if output_path contains 'prompt', render the audio under prompt_music; otherwise under music
            if 'prompt' in output_path:
                music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            else:
                music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))

            midi_obj.dump(output_path)
            # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj

class MidiDecoder4CP(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, tempo, chord, beat, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if 'time_signature' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif token_with_7infos[feature2idx['type']] == 'Metrical':
                if 'time_signature' in token_with_7infos[feature2idx['beat']]:
                    cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
                    bar_pos += cur_bar_resol
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
                elif token_with_7infos[feature2idx['beat']] == 'Bar':
                    bar_pos += cur_bar_resol
                elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
                    # chord and tempo
                    midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
            elif token_with_7infos[feature2idx['type']] == 'Note':
                # instrument token
                if len(feature2idx) == 8 or len(feature2idx) == 5:
                    if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                        cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
                else:
                    cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 80
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except Exception:
                    continue
            else: # when a new bar starts without a beat
                continue

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114  # program 114 marks the drum track here
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
            if not os.path.exists(output_music_dir):
                os.makedirs(output_music_dir)
            midi_obj.dump(output_path)
            save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            # name the rendered wav after the midi file inside the music directory
            music_path = os.path.join(output_music_dir, output_path.split('/')[-1].replace('.mid', '.wav'))
            save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj

class MidiDecoder4NB(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}

        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
                bar_pos += cur_bar_resol
            elif 'NNN' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            # instrument token
            if len(feature2idx) == 8 or len(feature2idx) == 5:
                if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                    cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
            else:
                cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
            if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
                if 'Beat' in str(token_with_7infos[feature2idx['beat']]): # when beat is not CONTI, the beat position is updated
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
                # update chord and tempo
                midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
                # note
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1] # duration between 1~192
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 90
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except Exception:
                    continue
            else: # when a new bar starts without a beat
                continue

        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114  # program 114 marks the drum track here
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)

        if isinstance(output_path, str) or isinstance(output_path, Path):
            output_path = str(output_path)
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
                os.makedirs(prompt_music_path)
            # if output_path contains 'prompt', render the audio under prompt_music; otherwise under music
            if 'prompt' in output_path:
                music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            else:
                music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            midi_obj.dump(output_path)
            # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj
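A minimal usage sketch of the decoder, assuming a REMI vocab object (with idx2event and feature_list) built elsewhere in the project; the token tensor stands in for hypothetical model output, not a real sample:

# sketch: decode REMI tokens to MIDI and render audio
import torch

decoder = MidiDecoder4REMI(vocab, in_beat_resolution=4, dataset_name='SOD')
tokens = torch.randint(0, len(vocab.idx2event), (1, 512))  # 1 x seq_len, placeholder indices
midi_obj = decoder(tokens, output_path='outputs/sample_000.mid')
# writes outputs/sample_000.mid and renders outputs/music/sample_000.wav via FluidSynth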
208
Amadeus/symbolic_encoding/metric_utils.py
Normal file
@@ -0,0 +1,208 @@
import torch
import numpy as np

from collections import Counter

# TODO: refactor hard-coded values
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
    generated_output = generated_output.squeeze(0)
    type_idx = feature_list.index('type')
    beat_idx = feature_list.index('beat')
    type_beat_list = []
    for token in generated_output:
        type_beat_list.append((token[type_idx].item(), token[beat_idx].item())) # type, beat

    last_note = 1
    beat_type_unmatched_error_list = []
    num_unmatched_errors = 0
    beat_backwards_error_list = []
    num_backwards_errors = 0
    for type_beat in type_beat_list:
        if type_beat[0] == 4: # same bar, new beat
            if type_beat[1] == 0 or type_beat[1] == 1:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(type_beat)
            if type_beat[1] <= last_note:
                num_backwards_errors += 1
                beat_backwards_error_list.append([last_note, type_beat])
            else:
                last_note = type_beat[1] # update last note
        elif type_beat[0] >= 5: # new bar, new beat
            if type_beat[1] == 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(type_beat)
            last_note = 1
    unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
    backwards_error_rate = num_backwards_errors / len(type_beat_list)
    type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
    return type_beat_errors_dict

def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
    generated_output = generated_output.squeeze(0)
    type_idx = feature_list.index('type')
    beat_idx = feature_list.index('beat')
    pitch_idx = feature_list.index('pitch')
    duration_idx = feature_list.index('duration')
    last_note = 1
    beat_type_unmatched_error_list = []
    num_unmatched_errors = 0
    beat_backwards_error_list = []
    num_backwards_errors = 0
    for token in generated_output:
        if token[type_idx].item() == 2: # Metrical
            if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(token)
            if token[beat_idx].item() == 1: # new bar
                last_note = 1 # last note will be updated in the next token
            elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
                num_backwards_errors += 1
                last_note = token[beat_idx].item() # update last note
                beat_backwards_error_list.append([last_note, token])
            else:
                last_note = token[beat_idx].item() # update last note
        if token[type_idx].item() == 3: # Note
            if token[beat_idx].item() != 0:
                num_unmatched_errors += 1
                beat_type_unmatched_error_list.append(token)
    unmatched_error_rate = num_unmatched_errors / len(generated_output)
    backwards_error_rate = num_backwards_errors / len(generated_output)
    type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
    return type_beat_errors_dict

def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
    generated_output = generated_output.squeeze(0)
    # to check beat-ordering errors
    beat_mask = vocab.total_mask['beat'].to(generated_output.device)
    beat_mask_for_target = beat_mask[generated_output]
    beat_target = generated_output * beat_mask_for_target
    bar_mask = vocab.total_mask['type'].to(generated_output.device)
    bar_mask_for_target = bar_mask[generated_output]
    bar_target = (generated_output+1) * bar_mask_for_target # the bar token is index 0 in the remi vocab, so add 1 to keep it after masking
    target = beat_target + bar_target
    target = target[target!=0]
    # collect beats between bars (idx=1)
    num_backwards_errors = 0
    collected_beats = []
    total_beats = 0
    for token in target:
        if token == 1 or 3 <= token <= 26: # Bar_None, or Bar_time_signature
            collected_beats_tensor = torch.tensor(collected_beats)
            diff = torch.diff(collected_beats_tensor)
            num_error_beats = torch.where(diff<=0)[0].shape[0]
            num_backwards_errors += num_error_beats
            collected_beats = []
        else:
            collected_beats.append(token.item())
            total_beats += 1
    if total_beats != 0:
        backwards_error_rate = num_backwards_errors / total_beats
    else:
        backwards_error_rate = 0
    # print(f"error rate in beat backwards: {backwards_error_rate}")
    return {'beat_backwards_error': backwards_error_rate}

def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
    bool_mask = mask.bool().flatten() # (b*t)
    pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
    valid_pred_beat_idx = pred_beat_idx[bool_mask] # valid beat_idx
    answer_type = answer_type.flatten() # (b*t)
    valid_type_input = answer_type[bool_mask] # valid answer_type
    type_beat_list = []
    for i in range(len(valid_pred_beat_idx)):
        type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item())) # type, beat
    input_beat = input_beat.flatten()
    valid_input_beat = input_beat[bool_mask]

    last_note = 1
    num_unmatched_errors = 0
    num_backwards_errors = 0
    for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
        # update last note
        if input_beat_idx.item() >= 1: # beat
            last_note = input_beat_idx.item()
        if type_beat[0] == 4: # same bar, new beat
            if type_beat[1] == 0 or type_beat[1] == 1:
                num_unmatched_errors += 1
            if type_beat[1] <= last_note:
                num_backwards_errors += 1
        elif type_beat[0] >= 5: # new bar, new beat
            if type_beat[1] == 0:
                num_unmatched_errors += 1
    return len(type_beat_list), num_unmatched_errors, num_backwards_errors

def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
    bool_mask = mask.bool().flatten() # (b*t)
    beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
    valid_beat_idx = beat_idx[bool_mask] # valid beat_idx
    answer_type = answer_type.flatten() # (b*t)
    valid_type_input = answer_type[bool_mask] # valid answer_type
    type_beat_list = []
    for i in range(len(valid_beat_idx)):
        type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item())) # type, beat
    input_beat = input_beat.flatten()
    valid_input_beat = input_beat[bool_mask]

    last_note = 1
    num_unmatched_errors = 0
    num_backwards_errors = 0
    for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
        # update last note
        if input_beat_idx.item() == 1: # bar
            last_note = 1
        elif input_beat_idx.item() >= 2: # new beat
            last_note = input_beat_idx.item()
        # check errors
        if type_beat[0] == 2: # Metrical
            if type_beat[1] == 0: # ignore
                num_unmatched_errors += 1
            elif type_beat[1] >= 2: # new beat
                if type_beat[1] <= last_note:
                    num_backwards_errors += 1
        elif type_beat[0] == 3: # Note
            if type_beat[1] != 0:
                num_unmatched_errors += 1
    return len(type_beat_list), num_unmatched_errors, num_backwards_errors

def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
    origin_beat_prob = prob_dict['beat'] # b x t x vocab_size
    arranged_beat_prob = arranged_prob_dict['beat'] # b x t x vocab_size

    # calculate similarity between the original and arranged beat predictions
    origin_beat_token = torch.argmax(origin_beat_prob, dim=-1) * mask # b x t
    arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask # b x t
    num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask==0)
    num_beat = torch.sum(mask==1)
    beat_sim = (num_same_beat / num_beat).item() # scalar

    # apply mask, shape of mask: b x t
    origin_beat_prob = origin_beat_prob * mask.unsqueeze(-1) # b x t x vocab_size
    arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)

    # calculate cosine similarity between the original and arranged beat probabilities
    origin_beat_prob = origin_beat_prob.flatten(0,1) # (b*t) x vocab_size
    arranged_beat_prob = arranged_beat_prob.flatten(0,1) # (b*t) x vocab_size
    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
    beat_cos_sim = cos(origin_beat_prob, arranged_beat_prob) # (b*t)
    # exclude invalid tokens (zero padding)
    beat_cos_sim = beat_cos_sim[mask.flatten().bool()] # num_valid_tokens
    beat_cos_sim = torch.mean(beat_cos_sim).item() # scalar
    return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}

def get_gini_coefficient(generated_output):
    if len(generated_output.shape) == 3:
        generated_output = generated_output.squeeze(0).tolist()
        gen_list = [tuple(x) for x in generated_output]
    else:
        gen_list = generated_output.squeeze(0).tolist()
    counts = Counter(gen_list).values()
    sorted_counts = sorted(counts)
    n = len(sorted_counts)
    cumulative_counts = np.cumsum(sorted_counts)
    cumulative_proportion = cumulative_counts / cumulative_counts[-1]

    lorenz_area = sum(cumulative_proportion[:-1]) / n # exclude the last element
    equality_area = 0.5 # the area under the line of perfect equality

    gini = (equality_area - lorenz_area) / equality_area
    return gini
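A quick sanity check of get_gini_coefficient on hand-made 1 x seq_len inputs: uniform token usage scores low and heavily repeated tokens score high, with the exact values below following from the discrete Lorenz sum the function computes:

# sketch: gini coefficient on toy token sequences
import torch

uniform = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])  # every token appears once
skewed = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 2]])   # one token dominates
print(get_gini_coefficient(uniform))  # 0.125 (small residual from the discrete approximation)
print(get_gini_coefficient(skewed))   # 0.875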
78
Amadeus/symbolic_encoding/midi2audio.py
Normal file
@@ -0,0 +1,78 @@
import argparse
import os
import subprocess
from pydub import AudioSegment

'''
This file is a modified version of midi2audio.py from https://github.com/bzamecnik/midi2audio
Author: Bohumír Zámečník (@bzamecnik)
License: MIT, see the LICENSE file
'''

__all__ = ['FluidSynth']

DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"
# DEFAULT_SOUND_FONT = '~/.fluidsynth/default_sound_font.sf2'
# DEFAULT_SAMPLE_RATE = 16000
# DEFAULT_GAIN = 0.20

class FluidSynth():
    def __init__(self, sound_font=DEFAULT_SOUND_FONT, sample_rate=DEFAULT_SAMPLE_RATE, gain=DEFAULT_GAIN):
        self.sample_rate = sample_rate
        self.sound_font = os.path.expanduser(sound_font)
        self.gain = gain

    def midi_to_audio(self, midi_file: str, audio_file: str, verbose=True):
        if verbose:
            stdout = None
        else:
            stdout = subprocess.DEVNULL

        # Convert MIDI to WAV
        subprocess.call(
            ['fluidsynth', '-ni', '-g', str(self.gain), self.sound_font, midi_file, '-F', audio_file, '-r', str(self.sample_rate)],
            stdout=stdout
        )

        # Convert WAV to MP3
        # mp3_path = audio_file.replace('.wav', '.mp3')
        # AudioSegment.from_wav(audio_file).export(mp3_path, format="mp3")

        # # Delete the temporary WAV file
        # os.remove(audio_file)

    def play_midi(self, midi_file):
        subprocess.call(['fluidsynth', '-i', '-g', str(self.gain), self.sound_font, midi_file, '-r', str(self.sample_rate)])

def parse_args(allow_synth=True):
    parser = argparse.ArgumentParser(description='Convert MIDI to audio via FluidSynth')
    parser.add_argument('midi_file', metavar='MIDI', type=str)
    if allow_synth:
        parser.add_argument('audio_file', metavar='AUDIO', type=str, nargs='?')
    parser.add_argument('-s', '--sound-font', type=str,
        default=DEFAULT_SOUND_FONT,
        help='path to a SF2 sound font (default: %s)' % DEFAULT_SOUND_FONT)
    parser.add_argument('-r', '--sample-rate', type=int, nargs='?',
        default=DEFAULT_SAMPLE_RATE,
        help='sample rate in Hz (default: %s)' % DEFAULT_SAMPLE_RATE)
    return parser.parse_args()

def main(allow_synth=True):
    args = parse_args(allow_synth)
    fs = FluidSynth(args.sound_font, args.sample_rate)
    if allow_synth and args.audio_file:
        fs.midi_to_audio(args.midi_file, args.audio_file)
    else:
        fs.play_midi(args.midi_file)

def main_play():
    """
    A method for the `midiplay` entry point. It omits the audio file from args.
    """
    main(allow_synth=False)

if __name__ == '__main__':
    main()
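A minimal usage sketch of the class; the paths are placeholders, and fluidsynth must be installed and on PATH for the subprocess call to work:

# sketch: render a MIDI file to WAV with a custom SoundFont
fs = FluidSynth(sound_font='/path/to/your_font.sf2', sample_rate=44100, gain=0.5)
fs.midi_to_audio('input.mid', 'output.wav', verbose=False)

The same conversion is available from the command line, e.g. `python midi2audio.py input.mid output.wav -s /path/to/your_font.sf2`.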