import re
import os
import sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from symusic import Score
from miditok import Octuple, TokenizerConfig
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP


class MuteWarn:
    """Context manager that temporarily silences stdout."""

    def __enter__(self):
        self._init_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._init_stdout
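
# Example: stdout chatter from noisy third-party calls is discarded inside
# the block (a minimal sketch; 'sample.mid' is a hypothetical path):
#
#     with MuteWarn():
#         converter.parse('sample.mid')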


def save_score_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        convert = converter.parse(midi_fn)
        convert.write('musicxml.png', fp=file_name)


def save_pianoroll_image_from_midi(midi_fn, file_name):
    assert isinstance(midi_fn, str)
    midi_obj_muspy = muspy.read_midi(midi_fn)
    midi_obj_muspy.show_pianoroll(track_label='program', preset='frame')
    plt.gcf().set_size_inches(20, 10)
    plt.savefig(file_name)
    plt.close()


def save_wav_from_midi(midi_fn, file_name, qpm=80):
    assert isinstance(midi_fn, str)
    with MuteWarn():
        music = muspy.read_midi(midi_fn)
        music.tempos[0].qpm = qpm
        music.write_audio(file_name, rate=44100, gain=3)


def save_wav_from_midi_fluidsynth(midi_fn, file_name, gain=3):
    assert isinstance(midi_fn, str)
    fs = FluidSynth(gain=gain)
    fs.midi_to_audio(midi_fn, file_name)
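

# Example (a minimal sketch; 'sample.mid' and the output file names below are
# hypothetical): rendering an existing MIDI file with the helpers above.
#
#     save_score_image_from_midi('sample.mid', 'sample_score.png')
#     save_pianoroll_image_from_midi('sample.mid', 'sample_roll.png')
#     save_wav_from_midi_fluidsynth('sample.mid', 'sample.wav', gain=1.5)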


class MidiDecoder4REMI:
    def __init__(
        self,
        vocab,
        in_beat_resolution,
        dataset_name
    ):
        self.vocab = vocab
        self.in_beat_resolution = in_beat_resolution
        self.dataset_name = dataset_name
        if dataset_name == 'SymphonyMIDI':
            self.gain = 0.7
        elif dataset_name == 'SOD' or dataset_name == 'LakhClean':
            self.gain = 1.1
        elif dataset_name == 'Pop1k7' or dataset_name == 'Pop909':
            self.gain = 2.5
        else:
            self.gain = 1.5

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: 1-D tensor of REMI token ids (a leading batch
        dimension of size 1 is squeezed away first).
        '''
        idx2event = self.vocab.idx2event
        if generated_output.dim() == 2:
            generated_output = generated_output.squeeze(0)
        events = [idx2event[token.item()] for token in generated_output]
        midi_obj = miditoolkit.midi.parser.MidiFile()
        if 'tempo' not in idx2event:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        cur_instr = 53 if self.dataset_name == 'BachChorale' else 0
        instr_notes_dict = defaultdict(list)
        for i in range(len(events) - 2):
            cur_event = events[i]
            # print(cur_event)
            name = cur_event.split('_')[0]
            attr = cur_event.split('_')
            if name == 'Bar':
                bar_pos += cur_bar_resol
                if 'time' in cur_event:
                    cur_num, cur_denom = attr[-1].split('/')
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif name == 'Beat':
                beat_pos = int(attr[1])
                cur_pos = bar_pos + beat_pos * default_in_beat_ticks
            elif name == 'Chord':
                chord_text = attr[1] + '_' + attr[2]
                midi_obj.markers.append(Marker(text=chord_text, time=cur_pos))
            elif name == 'Tempo':
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=int(attr[1]), time=cur_pos))
            elif name == 'Instrument':
                cur_instr = int(attr[1])
            else:
                if len(self.vocab.feature_list) == 7 or len(self.vocab.feature_list) == 8:
                    if 'Note_Pitch' in events[i] and \
                            'Note_Duration' in events[i+1] and \
                            'Note_Velocity' in events[i+2]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = int(events[i+2].split('_')[-1])
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))
                elif len(self.vocab.feature_list) == 4 or len(self.vocab.feature_list) == 5:
                    if 'Note_Pitch' in events[i] and \
                            'Note_Duration' in events[i+1]:
                        pitch = int(events[i].split('_')[-1])
                        duration = int(events[i+1].split('_')[-1])
                        duration = duration * default_in_beat_ticks
                        end = cur_pos + duration
                        velocity = 90
                        instr_notes_dict[cur_instr].append(
                            Note(
                                pitch=pitch,
                                start=cur_pos,
                                end=end,
                                velocity=velocity))
        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)
        if isinstance(output_path, (str, Path)):
            output_path = str(output_path)
            # make subdirectories for rendered audio
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
                os.makedirs(prompt_music_path)
            # audio for prompt files goes to prompt_music/, everything else to music/
            if 'prompt' in output_path:
                music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            else:
                music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            midi_obj.dump(output_path)
            # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj
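
# Usage sketch for MidiDecoder4REMI (hypothetical names: `vocab` must expose
# `idx2event` and `feature_list`, and `generated_tokens` is a 1-D tensor of
# REMI token ids produced by the model):
#
#     decoder = MidiDecoder4REMI(vocab, in_beat_resolution=4, dataset_name='SOD')
#     midi_obj = decoder(generated_tokens, output_path='outputs/sample.mid')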


class MidiDecoder4CP(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_chord_tempo(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0:
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, tempo, chord, beat, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        cur_instr = 53 if self.dataset_name == 'BachChorale' else 0  # default until an instrument token appears
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if 'time_signature' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            elif token_with_7infos[feature2idx['type']] == 'Metrical':
                if 'time_signature' in token_with_7infos[feature2idx['beat']]:
                    cur_num, cur_denom = token_with_7infos[feature2idx['beat']].split('_')[-1].split('/')
                    bar_pos += cur_bar_resol
                    new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                    cur_bar_resol = new_bar_resol
                    midi_obj.time_signature_changes.append(
                        TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
                elif token_with_7infos[feature2idx['beat']] == 'Bar':
                    bar_pos += cur_bar_resol
                elif 'Beat' in str(token_with_7infos[feature2idx['beat']]):
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks  # ticks
                    # chord and tempo
                    midi_obj = self._update_chord_tempo(midi_obj, cur_pos, token_with_7infos, feature2idx)
            elif token_with_7infos[feature2idx['type']] == 'Note':
                # instrument token
                if len(feature2idx) == 8 or len(feature2idx) == 5:
                    if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                        cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
                else:
                    cur_instr = 53 if self.dataset_name == 'BachChorale' else 0
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1]
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 80
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except Exception:
                    continue
            else:  # a new bar started without a beat token
                continue
        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)
        if isinstance(output_path, (str, Path)):
            output_path = str(output_path)
            output_music_dir = os.path.join(os.path.dirname(output_path), 'music')
            if not os.path.exists(output_music_dir):
                os.makedirs(output_music_dir)
            midi_obj.dump(output_path)
            save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            wav_path = os.path.join(output_music_dir, os.path.basename(output_path).replace('.mid', '.wav'))
            save_wav_from_midi_fluidsynth(output_path, wav_path, gain=self.gain)
        return midi_obj


class MidiDecoder4NB(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _update_additional_info(self, midi_obj, cur_pos, token_with_7infos, feature2idx):
        if len(feature2idx) == 7 or len(feature2idx) == 8:
            # chord
            if token_with_7infos[feature2idx['chord']] != 'CONTI' and token_with_7infos[feature2idx['chord']] != 0 and token_with_7infos[feature2idx['chord']] != 'Chord_N_N':
                midi_obj.markers.append(
                    Marker(text=str(token_with_7infos[feature2idx['chord']]), time=cur_pos))
            # tempo
            if token_with_7infos[feature2idx['tempo']] != 'CONTI' and token_with_7infos[feature2idx['tempo']] != 0 and token_with_7infos[feature2idx['tempo']] != "Tempo_N_N":
                tempo = int(token_with_7infos[feature2idx['tempo']].split('_')[-1])
                midi_obj.tempo_changes.append(
                    TempoChange(tempo=tempo, time=cur_pos))
            return midi_obj
        elif len(feature2idx) == 4 or len(feature2idx) == 5:
            return midi_obj

    def __call__(self, generated_output, output_path=None):
        '''
        generated_output: tensor, batch x seq_len x num_types
        num_types includes: type, beat, chord, tempo, instrument, pitch, duration, velocity
        '''
        idx2event = self.vocab.idx2event
        feature_keys = self.vocab.feature_list
        feature2idx = {key: idx for idx, key in enumerate(feature_keys)}
        midi_obj = miditoolkit.midi.parser.MidiFile()
        if len(feature2idx) == 4 or len(feature2idx) == 5:
            default_tempo = 95
            midi_obj.tempo_changes.append(
                TempoChange(tempo=default_tempo, time=0))
        default_ticks_per_beat = 480
        default_in_beat_ticks = 480 // self.in_beat_resolution
        cur_pos = 0
        bar_pos = 0
        cur_bar_resol = 0
        beat_pos = 0
        cur_instr = 53 if self.dataset_name == 'BachChorale' else 0  # default until an instrument token appears
        instr_notes_dict = defaultdict(list)
        generated_output = generated_output.squeeze(0)
        for i in range(len(generated_output)):
            token_with_7infos = []
            for kidx, key in enumerate(feature_keys):
                token_with_7infos.append(idx2event[key][generated_output[i][kidx].item()])
            # type token
            if token_with_7infos[feature2idx['type']] == 'Empty_Bar' or token_with_7infos[feature2idx['type']] == 'SNN':
                bar_pos += cur_bar_resol
            elif 'NNN' in token_with_7infos[feature2idx['type']]:
                cur_num, cur_denom = token_with_7infos[feature2idx['type']].split('_')[-1].split('/')
                bar_pos += cur_bar_resol
                new_bar_resol = int(default_ticks_per_beat * int(cur_num) * (4 / int(cur_denom)))
                cur_bar_resol = new_bar_resol
                midi_obj.time_signature_changes.append(
                    TimeSignature(numerator=int(cur_num), denominator=int(cur_denom), time=bar_pos))
            # instrument token
            if len(feature2idx) == 8 or len(feature2idx) == 5:
                if token_with_7infos[feature2idx['instrument']] != 0 and token_with_7infos[feature2idx['instrument']] != 'CONTI':
                    cur_instr = int(token_with_7infos[feature2idx['instrument']].split('_')[-1])
            else:
                cur_instr = 53 if self.dataset_name == 'BachChorale' else 0
            if 'Beat' in str(token_with_7infos[feature2idx['beat']]) or 'CONTI' in str(token_with_7infos[feature2idx['beat']]):
                if 'Beat' in str(token_with_7infos[feature2idx['beat']]):  # beat position only advances when the token is not CONTI
                    beat_pos = int(token_with_7infos[feature2idx['beat']].split('_')[1])
                    cur_pos = bar_pos + beat_pos * default_in_beat_ticks  # ticks
                # update chord and tempo
                midi_obj = self._update_additional_info(midi_obj, cur_pos, token_with_7infos, feature2idx)
                # note
                try:
                    pitch = token_with_7infos[feature2idx['pitch']].split('_')[-1]
                    duration = token_with_7infos[feature2idx['duration']].split('_')[-1]  # duration between 1~192
                    duration = int(duration) * default_in_beat_ticks
                    if len(feature2idx) == 7 or len(feature2idx) == 8:
                        velocity = token_with_7infos[feature2idx['velocity']].split('_')[-1]
                    else:
                        velocity = 90
                    end = cur_pos + duration
                    instr_notes_dict[cur_instr].append(
                        Note(
                            pitch=int(pitch),
                            start=cur_pos,
                            end=end,
                            velocity=int(velocity))
                    )
                except Exception:
                    continue
            else:  # a new bar started without a beat token
                continue
        # save midi
        for instr, notes in instr_notes_dict.items():
            instrument_name = PROGRAM_INSTRUMENT_MAP[instr]
            is_drum = instr == 114
            instr_track = Instrument(instr, is_drum=is_drum, name=instrument_name)
            instr_track.notes = notes
            midi_obj.instruments.append(instr_track)
        if isinstance(output_path, (str, Path)):
            output_path = str(output_path)
            music_path = os.path.join(os.path.dirname(output_path), 'music')
            prompt_music_path = os.path.join(os.path.dirname(output_path), 'prompt_music')
            if not os.path.exists(music_path):
                os.makedirs(music_path)
            if not os.path.exists(prompt_music_path):
                os.makedirs(prompt_music_path)
            # audio for prompt files goes to prompt_music/, everything else to music/
            if 'prompt' in output_path:
                music_path = os.path.join(prompt_music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            else:
                music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
            midi_obj.dump(output_path)
            # save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
            # save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
        return midi_obj


class MidiDecoder4Octuple(MidiDecoder4REMI):
    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
        """
        Input:
            t: torch.Tensor of shape (1, N, M)
        Behavior:
            drops every token row that contains any of the special ids 0, 1, 2, 3
        Returns:
            torch.Tensor with the batch dimension kept, shape (1, N_filtered, M)
        """
        if t.dim() != 3:
            raise ValueError("input tensor must be 3-D (batch, seq_len, feature)")
        # build a mask: True means the row contains none of 0, 1, 2, 3
        exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
        mask = ~torch.isin(t[0], exclude_vals).any(dim=1)
        # filter the rows, keeping the batch dimension
        filtered_tensor = t[0][mask].unsqueeze(0)
        return filtered_tensor
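
    # Example: for t = torch.tensor([[[5, 6], [0, 7], [8, 9]]]) the middle row
    # contains the special id 0, so the result is tensor([[[5, 6], [8, 9]]])
    # with shape (1, 2, 2).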

    def __call__(self, generated_output, output_path=None):
        config = TokenizerConfig(
            use_time_signatures=True,
            use_tempos=True,
            use_velocities=True,
            use_programs=True,
            remove_duplicated_notes=True,
            delete_equal_successive_tempo_changes=True,
        )
        config.additional_params["max_bar_embedding"] = 512
        tokenizer = Octuple(config)
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # generated_output = generated_output[:, 1:generated_output.shape[1]-1, :]  # remove sos token
        generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
        print(output_path)
        try:
            tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
            tok_seq.dump_midi(output_path)
        except Exception as e:
            print(generated_output)
            print(f" × error while writing the MIDI file: {output_path} -> {e}")
            tok_seq = None
        return tok_seq
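

# Usage sketch (hypothetical names: `vocab` and `generated_tokens` stand for
# the vocabulary object and a (1, seq_len, 8) tensor of generated Octuple
# rows produced elsewhere in this project):
#
#     decoder = MidiDecoder4Octuple(vocab, in_beat_resolution=4, dataset_name='SOD')
#     score = decoder(generated_tokens, output_path='outputs/sample_octuple.mid')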