1029 add octuple
@@ -8,9 +8,6 @@ import torch
 import torch.multiprocessing as mp
 from torch.distributed import init_process_group, destroy_process_group
-
-from accelerate import Accelerator
-from accelerate.utils import set_seed
 
 import wandb
 import hydra
 from hydra.core.hydra_config import HydraConfig
@@ -20,6 +17,8 @@ from omegaconf import DictConfig, OmegaConf
 from accelerate import Accelerator
 from accelerate.utils import set_seed
 
+from miditok import Octuple, TokenizerConfig
+
 from Amadeus.symbolic_encoding import data_utils, decoding_utils
 from Amadeus.symbolic_encoding.data_utils import get_emb_total_size
 from Amadeus import model_zoo, trainer_accelerate as trainer
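
Note: for context, a minimal sketch of how the newly imported miditok classes are typically used, assuming a recent miditok release; the configuration values and the input path below are illustrative, not the repository's actual settings.

from pathlib import Path

from miditok import Octuple, TokenizerConfig

# Build an Octuple tokenizer; num_velocities/use_programs are illustrative choices.
tok_config = TokenizerConfig(num_velocities=32, use_programs=True)
tokenizer = Octuple(tok_config)

# Tokenize a MIDI file; Octuple yields one multi-feature token per note.
tokens = tokenizer(Path("example.mid"))  # hypothetical input path
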
@@ -99,7 +98,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
     out_vocab_path = Path(save_dir) / f'vocab_{dataset_name}_{encoding_scheme}{num_features}.json'
 
     # get vocab
-    vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB'}
+    vocab_name = {'remi':'LangTokenVocab', 'cp':'MusicTokenVocabCP', 'nb':'MusicTokenVocabNB', 'oct':'MusicTokenVocabOct'}
     selected_vocab_name = vocab_name[encoding_scheme]
 
     vocab = getattr(vocab_utils, selected_vocab_name)(
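
Note: the vocabulary class is looked up by name from this table and resolved with getattr. A minimal, standalone sketch of that dispatch pattern is below; the stand-in classes and empty constructor are placeholders, since the real vocab classes and their arguments are not shown in this hunk.

import types

# Stand-in module with dummy classes so the dispatch can be run in isolation.
vocab_utils = types.SimpleNamespace(
    LangTokenVocab=type('LangTokenVocab', (), {}),
    MusicTokenVocabCP=type('MusicTokenVocabCP', (), {}),
    MusicTokenVocabNB=type('MusicTokenVocabNB', (), {}),
    MusicTokenVocabOct=type('MusicTokenVocabOct', (), {}),
)

encoding_scheme = 'oct'
vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP',
              'nb': 'MusicTokenVocabNB', 'oct': 'MusicTokenVocabOct'}

# Resolve the class object from its name, then instantiate it.
vocab_cls = getattr(vocab_utils, vocab_name[encoding_scheme])  # -> MusicTokenVocabOct
vocab = vocab_cls()  # the real constructor takes the arguments shown in the hunk above
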
@@ -159,7 +158,7 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
     focal_gamma = config.train_params.focal_gamma
     if encoding_scheme == 'remi':
         loss_fn = NLLLoss4REMI(focal_alpha=focal_alpha, focal_gamma=focal_gamma)
-    elif encoding_scheme in ['cp', 'nb']:
+    elif encoding_scheme in ['cp', 'nb', 'oct']:
        if config.use_diff is False:
            loss_fn = NLLLoss4CompoundToken(feature_list=symbolic_dataset.vocab.feature_list, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
        else:
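
Note: 'oct' joins 'cp' and 'nb' under the compound-token loss because Octuple predicts several feature streams (pitch, velocity, duration, ...) at each position. A rough sketch of the underlying idea of a per-feature NLL sum; this is an illustration only, not the repository's NLLLoss4CompoundToken.

import torch
import torch.nn.functional as F

def compound_nll(logits_per_feature, targets_per_feature):
    """Sum the negative log-likelihood over every compound-token feature stream."""
    total = 0.0
    for logits, target in zip(logits_per_feature, targets_per_feature):
        # logits: (batch, seq, feature_vocab_size), target: (batch, seq)
        total = total + F.cross_entropy(logits.transpose(1, 2), target)
    return total
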
@@ -181,11 +180,11 @@ def preapre_sybmolic(config: DictConfig, save_dir: str, rank: int) -> trainer.La
         in_beat_resolution = in_beat_resolution_dict[dataset_name]
     except KeyError:
         in_beat_resolution = 4
-    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
+    midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB', 'oct':'MidiDecoder4Octuple'}
     midi_decoder = getattr(decoding_utils, midi_decoder_dict[encoding_scheme])(vocab=symbolic_dataset.vocab, in_beat_resolution=in_beat_resolution, dataset_name=dataset_name)
 
     # Select trainer class based on encoding scheme
-    trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken'}
+    trainer_option_dict = {'remi': 'LanguageModelTrainer4REMI', 'cp': 'LanguageModelTrainer4CompoundToken', 'nb':'LanguageModelTrainer4CompoundToken', 'oct':'LanguageModelTrainer4CompoundToken'}
     trainer_option = trainer_option_dict[encoding_scheme]
     sampling_method = None
     sampling_threshold = 0.99
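
Note: the new 'oct' key has to appear in every lookup table touched by this commit (vocab, loss branch, MIDI decoder, trainer). A hypothetical sanity check along these lines could catch an encoding scheme that is only partially wired in:

ENCODING_SCHEMES = ('remi', 'cp', 'nb', 'oct')

def check_dispatch_tables(**tables):
    """Assert that every dispatch table covers every supported encoding scheme."""
    for name, table in tables.items():
        missing = set(ENCODING_SCHEMES) - set(table)
        assert not missing, f"{name} is missing entries for: {sorted(missing)}"

# Usage (with the dicts defined in preapre_sybmolic above):
# check_dispatch_tables(vocab_name=vocab_name,
#                       midi_decoder_dict=midi_decoder_dict,
#                       trainer_option_dict=trainer_option_dict)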