1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -1,13 +1,17 @@
import re
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from symusic import Score
from miditok import Octuple, TokenizerConfig
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4Octuple(MidiDecoder4REMI):
    """Decode generated Octuple token tensors into MIDI using miditok.

    Calling an instance filters out rows containing the ids 0-3, decodes the
    remaining rows with a miditok ``Octuple`` tokenizer and, when an output
    path is supplied, writes the result to disk as a MIDI file.
    """

    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
        """Drop every row of the first batch item that contains 0, 1, 2 or 3.

        Only ``t[0]`` is inspected; any further batch entries are discarded.
        NOTE(review): ids 0-3 are presumably the tokenizer's special tokens
        (pad/bos/eos/mask) -- confirm against the vocabulary.

        Args:
            t: 3-D tensor, documented by the original author as having shape
                ``(1, N, M)`` (batch, sequence length, features).

        Returns:
            Tensor of shape ``(1, N_filtered, M)`` holding the rows of
            ``t[0]`` in which none of the values 0, 1, 2, 3 occurs.

        Raises:
            ValueError: If ``t`` is not 3-dimensional.
        """
        if t.dim() != 3:
            raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")

        rows = t[0]
        exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
        # Broadcast-compare every element against the excluded ids:
        # (N, M, 1) == (4,) -> (N, M, 4).
        hits = rows.unsqueeze(-1) == exclude_vals
        # Reduce one axis at a time instead of ``any(dim=(1, 2))`` so this
        # also runs on PyTorch versions without tuple-dim support for any().
        row_has_excluded = hits.any(dim=-1).any(dim=-1)
        # Keep only the clean rows and restore the batch dimension.
        return rows[~row_has_excluded].unsqueeze(0)

    def _get_tokenizer(self):
        """Return the Octuple tokenizer, building it on first use."""
        tokenizer = getattr(self, "_octuple_tokenizer", None)
        if tokenizer is None:
            config = TokenizerConfig(
                use_time_signatures=True,
                use_tempos=True,
                use_velocities=True,
                use_programs=True,
                remove_duplicated_notes=True,
                delete_equal_successive_tempo_changes=True,
            )
            config.additional_params["max_bar_embedding"] = 512
            tokenizer = Octuple(config)
            # The configuration is constant, so one tokenizer can be reused
            # across calls instead of being rebuilt for every decode.
            self._octuple_tokenizer = tokenizer
        return tokenizer

    def __call__(self, generated_output, output_path=None):
        """Decode ``generated_output`` and optionally write it as a MIDI file.

        Args:
            generated_output: 3-D token tensor; see
                ``remove_rows_with_exact_0_1_2_3`` for the expected layout.
            output_path: Destination of the MIDI file. Missing parent
                directories are created. When ``None``, the tokens are
                decoded but nothing is written to disk.

        Returns:
            The object returned by ``tokenizer.decode`` (presumably a symusic
            ``Score`` -- confirm against the installed miditok version), or
            ``None`` if decoding or writing failed.

        Raises:
            ValueError: If ``generated_output`` is not 3-dimensional.
        """
        tokenizer = self._get_tokenizer()

        if output_path is not None:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

        # Rows containing ids 0-3 are filtered out before decoding.
        generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
        print(output_path)
        try:
            tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
            if output_path is not None:
                tok_seq.dump_midi(output_path)
        except Exception as e:
            # Deliberate best-effort: a sample that cannot be decoded is
            # reported and skipped rather than aborting the whole run.
            print(generated_output)
            print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
            tok_seq = None
        return tok_seq