1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -22,6 +22,8 @@ class Augmentor:
self.chord_idx = self.feature_list.index('chord')
def _get_shift(self, segment):
if self.encoding_scheme == 'oct':
return 0
# the pitch vocab has ignore token in 0 index
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
pitch_mask = segment != 0

View File

@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)
@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
for i in range(len(self.data_list)):
tune_in_idx, tune_name = self.data_list[i]
tune_in_idx = torch.LongTensor(tune_in_idx)
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
eos_token = torch.LongTensor(self.eos_token)
else:
eos_token = torch.LongTensor(self.eos_token)

View File

@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
print(f"Error encoding caption for tune {tune_name}: {e}")
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
return segment, tensor_mask, tune_name, encoded_caption
if self.data_type == 'train':
augmented_segment = self.augmentor(segment)
return augmented_segment, tensor_mask, tune_name, encoded_caption
else:
return segment, tensor_mask, tune_name, encoded_caption
def get_segments_with_tune_idx(self, tune_name, seg_order):
'''
@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
self.data_type = data_type
self.augmentor = augmentor
self.eos_token = vocab.eos_token
self.vocab = vocab
self.compile_function = VanillaTransformer_compiler(
data_list=self.data_list,
augmentor=self.augmentor,
@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
except Exception as e:
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
if self.data_type == 'train':
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
segment = self.augmentor(segment)
# use input_ids replace tune_name
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption

View File

@ -1,13 +1,17 @@
import re
import os, sys
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
import torch
from music21 import converter
import muspy
import miditoolkit
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
from symusic import Score
from miditok import Octuple, TokenizerConfig
from .midi2audio import FluidSynth
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
midi_obj.dump(output_path)
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
return midi_obj
class MidiDecoder4Octuple(MidiDecoder4REMI):
def __init__(self, vocab, in_beat_resolution, dataset_name):
super().__init__(vocab, in_beat_resolution, dataset_name)
def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
"""
输入:
t: torch.Tensor, 形状 (1, N, M)
功能:
删除包含独立元素 0, 1, 2, 3 的子tensor行
返回:
torch.Tensor, 同样保持 batch 维度 (1, N_filtered, M)
"""
if t.dim() != 3:
raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")
# 构造一个 maskTrue 表示该行不包含 0,1,2,3
exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)
# 判断每一行是否含有这些值
mask = ~((t[0][..., None] == exclude_vals).any(dim=(1, 2)))
# 过滤行并保留 batch 维
filtered_tensor = t[0][mask].unsqueeze(0)
return filtered_tensor
def __call__(self, generated_output, output_path=None):
config = TokenizerConfig(
use_time_signatures=True,
use_tempos=True,
use_velocities=True,
use_programs=True,
remove_duplicated_notes=True,
delete_equal_successive_tempo_changes=True,
)
config.additional_params["max_bar_embedding"] = 512
tokenizer = Octuple(config)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# generated_output = generated_output[:, 1:generated_output.shape[1]-1, :] # remove sos token
generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
print(output_path)
try:
tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
tok_seq.dump_midi(output_path)
except Exception as e:
print(generated_output)
print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
tok_seq = None
return tok_seq