1021 add flexable attr control
This commit is contained in:
@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
|
||||
|
||||
return last_ckpt_fn, config_path, metadata_path, vocab_path
|
||||
|
||||
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
|
||||
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
|
||||
# if config is a path, load it
|
||||
if isinstance(config, (str, Path)):
|
||||
from omegaconf import OmegaConf
|
||||
config = OmegaConf.load(config)
|
||||
config = wandb_style_config_to_omega_config(config)
|
||||
|
||||
nn_params = config.nn_params
|
||||
for_evaluation = True
|
||||
if condition_dataset is not None:
|
||||
print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
|
||||
config.dataset = condition_dataset
|
||||
for_evaluation = False
|
||||
dataset_name = config.dataset
|
||||
vocab_path = Path(vocab_path)
|
||||
|
||||
@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
|
||||
input_length=config.train_params.input_length,
|
||||
first_pred_feature=config.data_params.first_pred_feature,
|
||||
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
|
||||
for_evaluation=True,
|
||||
for_evaluation=for_evaluation
|
||||
)
|
||||
|
||||
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
|
||||
@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
|
||||
split_ratio = config.data_params.split_ratio
|
||||
# test_set = []
|
||||
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
|
||||
|
||||
# get proper prediction order according to the encoding scheme and target feature in the config
|
||||
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
|
||||
|
||||
@ -480,6 +490,28 @@ class Evaluator:
|
||||
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
|
||||
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
|
||||
|
||||
def generate_samples_with_attrCtl(self, save_dir, num_target_measures, tuneidx, tune_name, first_pred_feature, sampling_method=None, threshold=None, temperature=1.0,generation_length=3072, attr_list=None):
|
||||
encoding_scheme = self.config.nn_params.encoding_scheme
|
||||
|
||||
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
|
||||
try:
|
||||
in_beat_resolution = in_beat_resolution_dict[self.config.dataset]
|
||||
except KeyError:
|
||||
in_beat_resolution = 4 # Default resolution if dataset is not found
|
||||
|
||||
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
|
||||
decoder_name = midi_decoder_dict[encoding_scheme]
|
||||
decoder = getattr(decoding_utils, decoder_name)(vocab=self.vocab, in_beat_resolution=in_beat_resolution, dataset_name=self.config.dataset)
|
||||
|
||||
tuneidx = tuneidx.cuda()
|
||||
generated_sample = self.model.generate(0, generation_length, condition=tuneidx, num_target_measures=num_target_measures, sampling_method=sampling_method, threshold=threshold, temperature=temperature, attr_list=attr_list)
|
||||
if encoding_scheme == 'nb':
|
||||
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
|
||||
decoder(generated_sample, output_path=str(save_dir / f"{tune_name}.mid"))
|
||||
|
||||
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
|
||||
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
|
||||
|
||||
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
|
||||
encoding_scheme = self.config.nn_params.encoding_scheme
|
||||
|
||||
|
||||
Reference in New Issue
Block a user