1029 add octuple
This commit is contained in:
@ -22,6 +22,8 @@ class Augmentor:
|
||||
self.chord_idx = self.feature_list.index('chord')
|
||||
|
||||
def _get_shift(self, segment):
|
||||
if self.encoding_scheme == 'oct':
|
||||
return 0
|
||||
# the pitch vocab has ignore token in 0 index
|
||||
if self.encoding_scheme == 'cp' or self.encoding_scheme == 'nb':
|
||||
pitch_mask = segment != 0
|
||||
|
||||
@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
|
||||
@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
|
||||
print(f"Error encoding caption for tune {tune_name}: {e}")
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
if self.data_type == 'train':
|
||||
augmented_segment = self.augmentor(segment)
|
||||
return augmented_segment, tensor_mask, tune_name, encoded_caption
|
||||
else:
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
|
||||
def get_segments_with_tune_idx(self, tune_name, seg_order):
|
||||
'''
|
||||
@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
self.data_type = data_type
|
||||
self.augmentor = augmentor
|
||||
self.eos_token = vocab.eos_token
|
||||
self.vocab = vocab
|
||||
self.compile_function = VanillaTransformer_compiler(
|
||||
data_list=self.data_list,
|
||||
augmentor=self.augmentor,
|
||||
@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
except Exception as e:
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
if self.data_type == 'train':
|
||||
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
|
||||
segment = self.augmentor(segment)
|
||||
# use input_ids replace tune_name
|
||||
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import re
|
||||
import os, sys
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from music21 import converter
|
||||
import muspy
|
||||
import miditoolkit
|
||||
from miditoolkit.midi.containers import Marker, Instrument, TempoChange, Note, TimeSignature
|
||||
from symusic import Score
|
||||
from miditok import Octuple, TokenizerConfig
|
||||
|
||||
from .midi2audio import FluidSynth
|
||||
from data_representation.constants import PROGRAM_INSTRUMENT_MAP
|
||||
@ -400,5 +404,62 @@ class MidiDecoder4NB(MidiDecoder4REMI):
|
||||
music_path = os.path.join(music_path, output_path.split('/')[-1].replace('.mid', '.wav'))
|
||||
midi_obj.dump(output_path)
|
||||
# save_pianoroll_image_from_midi(output_path, output_path.replace('.mid', '.png'))
|
||||
save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
# save_wav_from_midi_fluidsynth(output_path, music_path, gain=self.gain)
|
||||
return midi_obj
|
||||
|
||||
class MidiDecoder4Octuple(MidiDecoder4REMI):
    """Decode generated Octuple token rows into a MIDI score via miditok's ``Octuple``."""

    def __init__(self, vocab, in_beat_resolution, dataset_name):
        super().__init__(vocab, in_beat_resolution, dataset_name)

    def _get_tokenizer(self):
        """Return the Octuple tokenizer, building it on first use.

        The configuration is constant, so the tokenizer is created once per
        decoder instance and reused by later calls instead of being rebuilt
        for every decoded sequence.
        """
        tokenizer = getattr(self, "_octuple_tokenizer", None)
        if tokenizer is None:
            config = TokenizerConfig(
                use_time_signatures=True,
                use_tempos=True,
                use_velocities=True,
                use_programs=True,
                remove_duplicated_notes=True,
                delete_equal_successive_tempo_changes=True,
            )
            config.additional_params["max_bar_embedding"] = 512
            tokenizer = Octuple(config)
            self._octuple_tokenizer = tokenizer
        return tokenizer

    def remove_rows_with_exact_0_1_2_3(self, t: torch.Tensor) -> torch.Tensor:
        """Drop every row of the first batch element that contains 0, 1, 2 or 3.

        Args:
            t: Tensor of shape (batch, seq_len, feature). Only ``t[0]`` is
                inspected; any further batch elements are discarded.

        Returns:
            Tensor of shape (1, N_filtered, feature) holding the rows of
            ``t[0]`` in which no element equals 0, 1, 2 or 3.

        Raises:
            ValueError: If ``t`` is not three-dimensional.
        """
        if t.dim() != 3:
            raise ValueError("输入 tensor 必须是三维 (batch, seq_len, feature)")

        rows = t[0]
        exclude_vals = torch.tensor([0, 1, 2, 3], device=t.device)

        # Compare every element against every excluded value: (N, M, 4).
        hits = rows.unsqueeze(-1) == exclude_vals

        # Reduce one axis at a time rather than passing a tuple of dims, so
        # this also runs on PyTorch builds whose ``any`` takes a single dim.
        keep = ~hits.any(dim=-1).any(dim=-1)

        # Filter the rows and restore the leading batch dimension.
        return rows[keep].unsqueeze(0)

    def __call__(self, generated_output, output_path=None):
        """Decode ``generated_output`` and optionally write it to a MIDI file.

        Args:
            generated_output: Tensor of shape (batch, seq_len, feature) with
                generated token ids; rows containing 0, 1, 2 or 3 are removed
                before decoding.
            output_path: Destination of the MIDI file. Missing parent
                directories are created. When ``None`` the sequence is only
                decoded and nothing is written to disk.

        Returns:
            The object returned by ``tokenizer.decode``, or ``None`` when
            decoding or writing failed.
        """
        tokenizer = self._get_tokenizer()

        # ``Path(None)`` raises TypeError, so only touch the filesystem when a
        # destination was actually supplied.
        if output_path is not None:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

        generated_output = self.remove_rows_with_exact_0_1_2_3(generated_output)
        print(output_path)
        try:
            tok_seq = tokenizer.decode(generated_output.squeeze(0).tolist())
            if output_path is not None:
                tok_seq.dump_midi(output_path)
        except Exception as e:
            # Deliberate best-effort: a failure is reported together with the
            # offending tokens and signalled to the caller by returning None.
            print(generated_output)
            print(f" × 生成 MIDI 文件时出错:{output_path} -> {e}")
            tok_seq = None

        return tok_seq
|
||||
|
||||
Reference in New Issue
Block a user