1021 add flexible attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@ -67,8 +67,19 @@ def get_best_ckpt_path_and_config(wandb_dir, code):
return last_ckpt_fn, config_path, metadata_path, vocab_path
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str):
def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str, vocab_path:str, condition_dataset: str=None):
# if config is a path, load it
if isinstance(config, (str, Path)):
from omegaconf import OmegaConf
config = OmegaConf.load(config)
config = wandb_style_config_to_omega_config(config)
nn_params = config.nn_params
for_evaluation = True
if condition_dataset is not None:
print(f"Conditioned dataset {condition_dataset} is used instead of {config.dataset}")
config.dataset = condition_dataset
for_evaluation = False
dataset_name = config.dataset
vocab_path = Path(vocab_path)
@ -104,7 +115,7 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
input_length=config.train_params.input_length,
first_pred_feature=config.data_params.first_pred_feature,
caption_path=config.captions_path if hasattr(config, 'captions_path') else None,
for_evaluation=True,
for_evaluation=for_evaluation
)
vocab_sizes = symbolic_dataset.vocab.get_vocab_size()
@ -114,7 +125,6 @@ def prepare_model_and_dataset_from_config(config: DictConfig, metadata_path:str,
split_ratio = config.data_params.split_ratio
# test_set = []
train_set, valid_set, test_set = symbolic_dataset.split_train_valid_test_set(dataset_name=config.dataset, ratio=split_ratio, seed=42, save_dir=None)
# get proper prediction order according to the encoding scheme and target feature in the config
prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)
@ -480,6 +490,28 @@ class Evaluator:
prompt = self.model.decoder._prepare_inference(self.model.decoder.net.start_token, 0, tuneidx, num_target_measures=8)
decoder(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_with_attrCtl(
    self,
    save_dir,
    num_target_measures,
    tuneidx,
    tune_name,
    first_pred_feature,
    sampling_method=None,
    threshold=None,
    temperature=1.0,
    generation_length=3072,
    attr_list=None,
):
    """Generate one conditioned sample with attribute control and export it as MIDI.

    The model is asked to generate from ``tuneidx`` with ``attr_list``
    forwarded to ``self.model.generate``. Two files are written under
    ``save_dir``: ``{tune_name}.mid`` (the generated sample) and
    ``{tune_name}_prompt.mid`` (the prompt built by the decoder's
    ``_prepare_inference``).

    Args:
        save_dir: Output directory; must support the ``/`` operator with a
            string (presumably a ``pathlib.Path`` -- verify against callers).
        num_target_measures: Forwarded to ``self.model.generate``.
        tuneidx: Conditioning input; moved to CUDA via ``.cuda()``
            (presumably a torch tensor -- TODO confirm).
        tune_name: Stem used for both output file names.
        first_pred_feature: Only used for the ``'nb'`` encoding, where it is
            passed to ``reverse_shift_and_pad_for_tensor``.
        sampling_method: Forwarded to ``self.model.generate``.
        threshold: Forwarded to ``self.model.generate``.
        temperature: Forwarded to ``self.model.generate``.
        generation_length: Forwarded positionally to ``self.model.generate``.
        attr_list: Attribute-control values forwarded to
            ``self.model.generate``.

    Raises:
        KeyError: If the configured encoding scheme is not one of
            ``'remi'``, ``'cp'`` or ``'nb'``.
    """
    scheme = self.config.nn_params.encoding_scheme

    # Per-dataset in-beat resolution; datasets not listed here fall back to 4.
    resolution_by_dataset = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = resolution_by_dataset.get(self.config.dataset, 4)

    # Resolve the decoder class in decoding_utils that matches the encoding.
    decoder_name_by_scheme = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
    decoder_cls = getattr(decoding_utils, decoder_name_by_scheme[scheme])
    midi_writer = decoder_cls(
        vocab=self.vocab,
        in_beat_resolution=in_beat_resolution,
        dataset_name=self.config.dataset,
    )

    condition = tuneidx.cuda()
    sample = self.model.generate(
        0,
        generation_length,
        condition=condition,
        num_target_measures=num_target_measures,
        sampling_method=sampling_method,
        threshold=threshold,
        temperature=temperature,
        attr_list=attr_list,
    )
    if scheme == 'nb':
        # The 'nb' encoding is post-processed before decoding.
        sample = reverse_shift_and_pad_for_tensor(sample, first_pred_feature)
    midi_writer(sample, output_path=str(save_dir / f"{tune_name}.mid"))

    # NOTE(review): the prompt is always prepared with num_target_measures=8,
    # independent of the num_target_measures argument -- confirm this is intended.
    inner_decoder = self.model.decoder
    prompt = inner_decoder._prepare_inference(
        inner_decoder.net.start_token, 0, condition, num_target_measures=8
    )
    midi_writer(prompt, output_path=str(save_dir / f"{tune_name}_prompt.mid"))
def generate_samples_unconditioned(self, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = self.config.nn_params.encoding_scheme