1029 add octuple
This commit is contained in:
@ -1,13 +1,17 @@
|
||||
import re
|
||||
import os, sys
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from music21 import converter
|
||||
import muspy
|
||||
import miditoolkit
|
||||
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
|
||||
from .midi2audio import FluidSynth
|
||||
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
|
||||
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
|
||||
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||||
midi_obj.dump(output_path)
|
||||
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||||
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
return midi_obj
|
||||
|
||||
class MidiDecoder4Octuple(MidiDecoder4REMI):
    """Decode generated Octuple token tensors into a score / MIDI file.

    Decoding is delegated to miditok's ``Octuple`` tokenizer, configured with
    time signatures, tempos, velocities and programs enabled.
    """

    # Token ids that cause a whole row to be dropped before decoding.
    # NOTE(review): presumably the special-token ids (PAD/BOS/EOS/MASK) of the
    # miditok vocabulary -- confirm against the tokenizer used for training.
    _EXCLUDED_TOKEN_IDS = (0, 1, 2, 3)

    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _get_tokenizer(self):
        """Return the Octuple tokenizer, building it on first use."""
        # getattr keeps this safe even if the instance was created without
        # running __init__ (e.g. via __new__).
        tokenizer = getattr(self, "_octuple_tokenizer", None)
        if tokenizer is None:
            config = TokenizerConfig(
                use_time_signatures=True,
                use_tempos=True,
                use_velocities=True,
                use_programs=True,
                remove_duplicated_notes=True,
                delete_equal_successive_tempo_changes=True,
            )
            config.additional_params["max_bar_embedding"] = 512
            tokenizer = Octuple(config)
            self._octuple_tokenizer = tokenizer
        return tokenizer

    def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
        """Drop every row that contains at least one of the values 0, 1, 2, 3.

        Args:
            t: 3-D tensor ``(batch, seq_len, feature)``. Only the first batch
                entry (``t[0]``) is processed; the intended shape is
                ``(1, N, M)``.

        Returns:
            Tensor of shape ``(1, N_filtered, M)`` holding the rows of
            ``t[0]`` in which no element equals 0, 1, 2 or 3.

        Raises:
            ValueError: If ``t`` is not three-dimensional.
        """
        if t.dim() != 3:
            raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")

        rows = t[0]
        exclude_vals = torch.tensor(self._EXCLUDED_TOKEN_IDS, device=t.device)

        # (N, M, 1) == (4,) broadcasts to (N, M, 4). Reduce the two trailing
        # axes one at a time: a tuple `dim` for `any` is not accepted by older
        # PyTorch releases, while chained single-axis reductions always work.
        row_has_excluded = (rows.unsqueeze(-1) == exclude_vals).any(dim=-1).any(dim=-1)

        # Keep the clean rows and restore the leading batch axis.
        return rows[~row_has_excluded].unsqueeze(0)

    def __call__(self, generated_output, output_path=None):
        """Decode ``generated_output`` and optionally write it as a MIDI file.

        Args:
            generated_output: 3-D token tensor; rows containing any of the
                excluded ids are removed before decoding.
            output_path: Destination MIDI path. Missing parent directories are
                created. When ``None`` the sequence is decoded and returned
                without writing anything to disk.

        Returns:
            The object returned by ``tokenizer.decode``, or ``None`` when
            decoding or writing failed (the error is printed, not raised).
        """
        tokenizer = self._get_tokenizer()

        if output_path is not None:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

        generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
        if output_path is not None:
            print(output_path)

        try:
            tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
            if output_path is not None:
                tok_seq.dump_midi(output_path)
        except Exception as e:
            # Deliberate best-effort: one bad sample must not abort a whole
            # generation run, so report it and signal failure with None.
            print(generated_output)
            print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
            tok_seq = None

        return tok_seq
|
||||
|
||||
Reference in New Issue
Block a user