1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -8,9 +8,6 @@ import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from accelerate import Accelerator
from accelerate.utils import set_seed
import wandb
import hydra
from hydra.core.hydra_config import HydraConfig
@ -20,6 +17,8 @@ from omegaconf import DictConfig, OmegaConf
from accelerate import Accelerator
from accelerate.utils import set_seed
from miditok import Octuple, TokenizerConfig
from Amadeus.symbolic_encoding import data_utils, decoding_utils
from Amadeus.symbolic_encoding.data_utils import get_emb_total_size
from Amadeus import model_zoo, trainer_accelerate as trainer
@ -99,7 +98,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
out_vocab_path = Path(save_dir) / f'vocab_{dataset_name}_{encoding_scheme}{num_features}.json'
# get vocab
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
selected_vocab_name = vocab_name[encoding_scheme]
vocab = getattr(vocab_utils, selected_vocab_name)(
@ -159,7 +158,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
focal_gamma = config.train_params.focal_gamma
if encoding_scheme == 'remi':
loss_fn = NLLLoss4REMI(focal_alpha=focal_alpha, focal_gamma=focal_gamma)
elif encoding_scheme in ['cp', 'nb']:
elif encoding_scheme in ['cp', 'nb', 'oct']:
if config.use_diff is False:
loss_fn = NLLLoss4CompoundToken(feature_list=symbolic_dataset.vocab.feature_list, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
else:
@ -181,11 +180,11 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
in_beat_resolution = in_beat_resolution_dict[dataset_name]
except KeyError:
in_beat_resolution = 4
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
midi_decoder = getattr(decoding_utils, midi_decoder_dict[encoding_scheme])(vocab=symbolic_dataset.vocab, in_beat_resolution=in_beat_resolution, dataset_name=dataset_name)
# Select trainer class based on encoding scheme
trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken'}
trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken', 'oct':'LanguageModelTrainer4CompoundToken'}
trainer_option = trainer_option_dict[encoding_scheme]
sampling_method = None
sampling_threshold = 0.99