466 lines
19 KiB
Python
466 lines
19 KiB
Python
import re
|
||
import os, sys
|
||
from pathlib import Path
|
||
|
||
import matplotlib.pyplot as plt
|
||
from collections import defaultdict
|
||
|
||
import torch
|
||
from music21 import converter
|
||
import muspy
|
||
import miditoolkit
|
||
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
|
||
from symusic import Score
|
||
from miditok import Octuple, TokenizerConfig
|
||
|
||
from .midi2audio import FluidSynth
|
||
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
|
||
|
||
class MuteWarn:
|
||
def __enter__(self):
|
||
self._init_stdout = sys.stdout
|
||
sys.stdout = open(os.devnull, "w")
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
sys.stdout.close()
|
||
sys.stdout = self._init_stdout
|
||
|
||
def save_score_image_from_midi(midi_fn, file_name):
|
||
assert isinstance(midi_fn, str)
|
||
with MuteWarn():
|
||
convert = converter.parse(midi_fn)
|
||
convert.write('musicxml.png', fp=file_name)
|
||
|
||
def save_pianoroll_image_from_midi(midi_fn, file_name):
|
||
assert isinstance(midi_fn, str)
|
||
midi_obj_muspy = muspy.read_midi(midi_fn)
|
||
midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
|
||
plt.gcf().set_size_inches(20, 10)
|
||
plt.savefig(file_name)
|
||
plt.close()
|
||
|
||
def save_wav_from_midi(midi_fn, file_name, qpm=80):
|
||
assert isinstance(midi_fn, str)
|
||
with MuteWarn():
|
||
music = muspy.read_midi(midi_fn)
|
||
music.tempos[0].qpm = qpm
|
||
music.write_audio(file_name, rate=44100, gain=3)
|
||
|
||
def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
|
||
assert isinstance(midi_fn, str)
|
||
fs = FluidSynth(gain=gain)
|
||
fs.midi_to_audio(midi_fn, file_name)
|
||
|
||
class MidiDecoder4REMI:
|
||
def __init__(
|
||
self,
|
||
vocab,
|
||
in_beat_resolution,
|
||
dataset_name
|
||
):
|
||
self.vocab = vocab
|
||
self.in_beat_resolution = in_beat_resolution
|
||
self.dataset_name = dataset_name
|
||
if dataset_name == 'SymphonyMIDI':
|
||
self.gain = 0.7
|
||
elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
|
||
self.gain = 1.1
|
||
elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
|
||
self.gain = 2.5
|
||
else:
|
||
self.gain = 1.5
|
||
|
||
def __call__(self, generated_output, output_path=None):
|
||
'''
|
||
generated_output: list of tensor, the tensor
|
||
'''
|
||
idx2event = self.vocab.idx2event
|
||
if generated_output.dim() == 2:
|
||
generated_output = generated_output.squeeze(0)
|
||
events = [idx2event[token.item()] for token in generated_output]
|
||
|
||
midi_obj = miditoolkit.midi.parser.MidiFile()
|
||
if 'tempo' not in idx2event.keys():
|
||
default_tempo = 95
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=default_tempo, time=0))
|
||
default_ticks_per_beat = 480
|
||
default_in_beat_ticks = 480 // self.in_beat_resolution
|
||
cur_pos = 0
|
||
bar_pos = 0
|
||
cur_bar_resol = 0
|
||
beat_pos = 0
|
||
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
|
||
instr_notes_dict = defaultdict(list)
|
||
for i in range(len(events)-2):
|
||
cur_event = events[i]
|
||
# print(cur_event)
|
||
name = cur_event.split('_')[0]
|
||
attr = cur_event.split('_')
|
||
if name == 'Bar':
|
||
bar_pos += cur_bar_resol
|
||
if 'time' in cur_event:
|
||
cur_num, cur_denom = attr[-1].split('/')
|
||
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
|
||
cur_bar_resol = new_bar_resol
|
||
midi_obj.time_signature_changes.append(
|
||
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
|
||
elif name == 'Beat':
|
||
beat_pos = int(attr[1])
|
||
cur_pos = bar_pos + beat_pos * default_in_beat_ticks
|
||
elif name == 'Chord':
|
||
chord_text = attr[1] + '_' + attr[2]
|
||
midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
|
||
elif name == 'Tempo':
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=int(attr[1]), time=cur_pos))
|
||
elif name == 'Instrument':
|
||
cur_instr = int(attr[1])
|
||
else:
|
||
if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
|
||
if 'Note_Pitch' in events[i] and \
|
||
'Note_Duration' in events[i+1] and \
|
||
'Note_Velocity' in events[i+2]:
|
||
pitch = int(events[i].split('_')[-1])
|
||
duration = int(events[i+1].split('_')[-1])
|
||
duration = duration * default_in_beat_ticks
|
||
end = cur_pos + duration
|
||
velocity = int(events[i+2].split('_')[-1])
|
||
instr_notes_dict[cur_instr].append(
|
||
Note(
|
||
pitch=pitch,
|
||
start=cur_pos,
|
||
end=end,
|
||
velocity=velocity))
|
||
elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
|
||
if 'Note_Pitch' in events[i] and \
|
||
'Note_Duration' in events[i+1]:
|
||
pitch = int(events[i].split('_')[-1])
|
||
duration = int(events[i+1].split('_')[-1])
|
||
duration = duration * default_in_beat_ticks
|
||
end = cur_pos + duration
|
||
velocity = 90
|
||
instr_notes_dict[cur_instr].append(
|
||
Note(
|
||
pitch=pitch,
|
||
start=cur_pos,
|
||
end=end,
|
||
velocity=velocity))
|
||
|
||
# save midi
|
||
for instr, notes in instr_notes_dict.items():
|
||
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
|
||
if instr == 114: is_drum = True
|
||
else: is_drum = False
|
||
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
|
||
instr_track.notes = notes
|
||
midi_obj.instruments.append(instr_track)
|
||
|
||
if isinstance(output_path, str) or isinstance(output_path, Path):
|
||
output_path = str(output_path)
|
||
# make subdir
|
||
music_path = os.path.join(os.path.dirname(output_path), 'music')
|
||
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
|
||
if not os.path.exists(music_path):
|
||
os.makedirs(music_path)
|
||
if not os.path.exists(prompt_music_path):
|
||
os.makedirs(prompt_music_path)
|
||
# if not contain 'prompt' in output_path, save prompt music
|
||
if 'prompt' in output_path:
|
||
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||
else:
|
||
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||
|
||
midi_obj.dump(output_path)
|
||
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||
return midi_obj
|
||
|
||
class MidiDecoder4CP(MidiDecoder4REMI):
|
||
def __init__(self, vocab, in_beat_resolution, dataset_name):
|
||
super().__init__(vocab, in_beat_resolution, dataset_name)
|
||
|
||
def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
|
||
if len(feature2idx) == 7 or len(feature2idx) == 8:
|
||
# chord
|
||
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
|
||
midi_obj.markers.append(
|
||
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
|
||
# tempo
|
||
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
|
||
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=tempo, time=cur_pos))
|
||
return midi_obj
|
||
elif len(feature2idx) == 4 or len(feature2idx) == 5:
|
||
return midi_obj
|
||
|
||
def __call__(self, generated_output, output_path=None):
|
||
'''
|
||
generated_output: tensor, batch x seq_len x num_types
|
||
num_types includes: type, tempo, chord,'beat, pitch, duration, velocity
|
||
'''
|
||
idx2event = self.vocab.idx2event
|
||
feature_keys = self.vocab.feature_list
|
||
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
|
||
|
||
midi_obj = miditoolkit.midi.parser.MidiFile()
|
||
if len(feature2idx) == 4 or len(feature2idx) == 5:
|
||
default_tempo = 95
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=default_tempo, time=0))
|
||
default_ticks_per_beat = 480
|
||
default_in_beat_ticks = 480 // self.in_beat_resolution
|
||
cur_pos = 0
|
||
bar_pos = 0
|
||
cur_bar_resol = 0
|
||
beat_pos = 0
|
||
instr_notes_dict = defaultdict(list)
|
||
generated_output = generated_output.squeeze(0)
|
||
for i in range(len(generated_output)):
|
||
token_with_7infos = []
|
||
for kidx, key in enumerate(feature_keys):
|
||
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
|
||
# type token
|
||
if 'time_signature' in token_with_7infos[feature2idx['type']]:
|
||
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
|
||
bar_pos += cur_bar_resol
|
||
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
|
||
cur_bar_resol = new_bar_resol
|
||
midi_obj.time_signature_changes.append(
|
||
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
|
||
elif token_with_7infos[feature2idx['type']] == 'Metrical':
|
||
if 'time_signature' in token_with_7infos[feature2idx['beat']]:
|
||
cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
|
||
bar_pos += cur_bar_resol
|
||
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
|
||
cur_bar_resol = new_bar_resol
|
||
midi_obj.time_signature_changes.append(
|
||
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
|
||
elif token_with_7infos[feature2idx['beat']] == 'Bar':
|
||
bar_pos += cur_bar_resol
|
||
elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
|
||
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
|
||
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
|
||
# chord and tempo
|
||
midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
|
||
elif token_with_7infos[feature2idx['type']] == 'Note':
|
||
# instrument token
|
||
if len(feature2idx) == 8 or len(feature2idx) == 5:
|
||
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
|
||
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
|
||
else:
|
||
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
|
||
try:
|
||
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
|
||
duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
|
||
duration = int(duration) * default_in_beat_ticks
|
||
if len(feature2idx) == 7 or len(feature2idx) == 8:
|
||
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
|
||
else:
|
||
velocity = 80
|
||
end = cur_pos + duration
|
||
instr_notes_dict[cur_instr].append(
|
||
Note(
|
||
pitch=int(pitch),
|
||
start=cur_pos,
|
||
end=end,
|
||
velocity=int(velocity))
|
||
)
|
||
except:
|
||
continue
|
||
else: # when new bar started without beat
|
||
continue
|
||
|
||
# save midi
|
||
for instr, notes in instr_notes_dict.items():
|
||
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
|
||
if instr == 114: is_drum = True
|
||
else: is_drum = False
|
||
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
|
||
instr_track.notes = notes
|
||
midi_obj.instruments.append(instr_track)
|
||
|
||
if isinstance(output_path, str) or isinstance(output_path, Path):
|
||
output_path = str(output_path)
|
||
output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
|
||
if not os.path.exists(output_music_dir):
|
||
os.makedirs(output_music_dir)
|
||
midi_obj.dump(output_path)
|
||
save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||
save_wav_from_midi_fluidsynth(output_path, output_music_dir.replace('.mid', '.wav'), gain=self.gain)
|
||
return midi_obj
|
||
|
||
class MidiDecoder4NB(MidiDecoder4REMI):
|
||
def __init__(self, vocab, in_beat_resolution, dataset_name):
|
||
super().__init__(vocab, in_beat_resolution, dataset_name)
|
||
|
||
def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
|
||
if len(feature2idx) == 7 or len(feature2idx) == 8:
|
||
# chord
|
||
if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
|
||
midi_obj.markers.append(
|
||
Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
|
||
# tempo
|
||
if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
|
||
tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=tempo, time=cur_pos))
|
||
return midi_obj
|
||
elif len(feature2idx) == 4 or len(feature2idx) == 5:
|
||
return midi_obj
|
||
|
||
def __call__(self, generated_output, output_path=None):
|
||
'''
|
||
generated_output: tensor, batch x seq_len x num_types
|
||
num_types includes: type, beat, chord, tempo, intrument, pitch, duration, velocity
|
||
'''
|
||
idx2event = self.vocab.idx2event
|
||
feature_keys = self.vocab.feature_list
|
||
feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
|
||
|
||
midi_obj = miditoolkit.midi.parser.MidiFile()
|
||
if len(feature2idx) == 4 or len(feature2idx) == 5:
|
||
default_tempo = 95
|
||
midi_obj.tempo_changes.append(
|
||
TempoChange(tempo=default_tempo, time=0))
|
||
default_ticks_per_beat = 480
|
||
default_in_beat_ticks = 480 // self.in_beat_resolution
|
||
cur_pos = 0
|
||
bar_pos = 0
|
||
cur_bar_resol = 0
|
||
beat_pos = 0
|
||
instr_notes_dict = defaultdict(list)
|
||
generated_output = generated_output.squeeze(0)
|
||
for i in range(len(generated_output)):
|
||
token_with_7infos = []
|
||
for kidx, key in enumerate(feature_keys):
|
||
token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
|
||
# type token
|
||
if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
|
||
bar_pos += cur_bar_resol
|
||
elif 'NNN' in token_with_7infos[feature2idx['type']]:
|
||
cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
|
||
bar_pos += cur_bar_resol
|
||
new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
|
||
cur_bar_resol = new_bar_resol
|
||
midi_obj.time_signature_changes.append(
|
||
TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
|
||
# instrument token
|
||
if len(feature2idx) == 8 or len(feature2idx) == 5:
|
||
if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
|
||
cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
|
||
else:
|
||
cur_instr = 0 if not self.dataset_name == 'BachChorale' else 53
|
||
if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
|
||
if 'Beat' in str(token_with_7infos[feature2idx['beat']]): # when beat is not CONTI beat is updated
|
||
beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
|
||
cur_pos = bar_pos + beat_pos * default_in_beat_ticks # ticks
|
||
# update chord and tempo
|
||
midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
|
||
# note
|
||
try:
|
||
pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
|
||
duration = token_with_7infos[feature2idx['duration']].split('_')[-1] # duration between 1~192
|
||
duration = int(duration) * default_in_beat_ticks
|
||
if len(feature2idx) == 7 or len(feature2idx) == 8:
|
||
velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
|
||
else:
|
||
velocity = 90
|
||
end = cur_pos + duration
|
||
instr_notes_dict[cur_instr].append(
|
||
Note(
|
||
pitch=int(pitch),
|
||
start=cur_pos,
|
||
end=end,
|
||
velocity=int(velocity))
|
||
)
|
||
except:
|
||
continue
|
||
else: # when new bar started without beat
|
||
continue
|
||
|
||
# save midi
|
||
for instr, notes in instr_notes_dict.items():
|
||
instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
|
||
if instr == 114: is_drum = True
|
||
else: is_drum = False
|
||
instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
|
||
instr_track.notes = notes
|
||
midi_obj.instruments.append(instr_track)
|
||
|
||
if isinstance(output_path, str) or isinstance(output_path, Path):
|
||
output_path = str(output_path)
|
||
music_path = os.path.join(os.path.dirname(output_path), 'music')
|
||
prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
|
||
if not os.path.exists(music_path):
|
||
os.makedirs(music_path)
|
||
if not os.path.exists(prompt_music_path):
|
||
os.makedirs(prompt_music_path)
|
||
# if not contain 'prompt' in output_path, save prompt music
|
||
if 'prompt' in output_path:
|
||
music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||
else:
|
||
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||
midi_obj.dump(output_path)
|
||
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||
return midi_obj
|
||
|
||
class MidiDecoder4Octuple(MidiDecoder4REMI):
|
||
def __init__(self, vocab, in_beat_resolution, dataset_name):
|
||
super().__init__(vocab, in_beat_resolution, dataset_name)
|
||
|
||
|
||
|
||
def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
输入:
|
||
t: torch.Tensor, 形状 (1, N, M)
|
||
功能:
|
||
删除包含独立元素 0, 1, 2, 3 的子tensor行
|
||
返回:
|
||
torch.Tensor, 同样保持 batch 维度 (1, N_filtered, M)
|
||
"""
|
||
if t.dim() != 3:
|
||
raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")
|
||
|
||
# 构造一个 mask,True 表示该行不包含 0,1,2,3
|
||
exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
|
||
|
||
# 判断每一行是否含有这些值
|
||
mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
|
||
|
||
# 过滤行并保留 batch 维
|
||
filtered_tensor = t[0][mask].unsqueeze(0)
|
||
|
||
return filtered_tensor
|
||
|
||
def __call__(self, generated_output, output_path=None):
|
||
config = TokenizerConfig(
|
||
use_time_signatures=True,
|
||
use_tempos=True,
|
||
use_velocities=True,
|
||
use_programs=True,
|
||
remove_duplicated_notes=True,
|
||
delete_equal_successive_tempo_changes=True,
|
||
)
|
||
config.additional_params["max_bar_embedding"] = 512
|
||
tokenizer = Octuple(config)
|
||
|
||
output_path = Path(output_path)
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
# generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
|
||
generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
|
||
print(output_path)
|
||
try:
|
||
tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
|
||
tok_seq.dump_midi(output_path)
|
||
except Exception as e:
|
||
print(generated_output)
|
||
print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
|
||
tok_seq = None
|
||
|
||
return tok_seq
|