first commit
210
generate.py
Normal file
@@ -0,0 +1,210 @@
import torch
from pathlib import Path
import argparse
import json
from collections import defaultdict
from omegaconf import OmegaConf, DictConfig
from transformers import T5Tokenizer, T5EncoderModel

from Amadeus.train_utils import adjust_prediction_order
from Amadeus.evaluation_utils import (
    get_dir_from_wandb_by_code,
    wandb_style_config_to_omega_config,
)
from Amadeus.symbolic_encoding import decoding_utils, data_utils
from data_representation import vocab_utils
from Amadeus import model_zoo
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor

def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-prompt",
        required=True,
        type=str,
        help="text prompt for generation",
    )
    parser.add_argument(
        "-output_dir",
        type=str,
        default="outputs",
        help="directory to save results",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=2048,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-text_encoder_model",
        type=str,
        default='google/flan-t5-large',
        help="pretrained text encoder model",
    )
    return parser
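
# Example invocation (illustrative run directory and prompt; flags as defined above):
#   python generate.py -wandb_exp_dir wandb/run-20240101_120000-abc123 \
#       -prompt "a calm piano piece" -output_dir outputs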


def get_best_ckpt_path_and_config(exp_dir):
    if exp_dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = exp_dir / 'files' / 'checkpoints'
    config_path = exp_dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))

    # prefer a checkpoint ending in 'last'; otherwise fall back to the
    # highest-numbered 'iter*' checkpoint
    if len(list(ckpt_dir.glob('*last.pt'))) > 0:
        last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
    else:
        pt_fns = sorted(list(ckpt_dir.glob('*.pt')),
                        key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
        last_ckpt_fn = pt_fns[-1]

    return last_ckpt_fn, config_path, vocab_path
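
# Expected run-directory layout (inferred from the lookups above; file names illustrative):
#   <wandb_exp_dir>/files/config.yaml
#   <wandb_exp_dir>/files/checkpoints/vocab*          (vocabulary file)
#   <wandb_exp_dir>/files/checkpoints/iter<N>_*.pt    (or a '*last.pt' checkpoint)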


def prepare_model_and_dataset_from_config(config: DictConfig, vocab_path: str):
    nn_params = config.nn_params
    vocab_path = Path(vocab_path)

    encoding_scheme = config.nn_params.encoding_scheme
    num_features = config.nn_params.num_features

    # pick the vocabulary class that matches the encoding scheme
    vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
    selected_vocab_name = vocab_name[encoding_scheme]

    vocab = getattr(vocab_utils, selected_vocab_name)(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features)

    # get the proper prediction order according to the encoding scheme and target feature in the config
    prediction_order = adjust_prediction_order(encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params)

    # create the Transformer model based on configuration parameters
    AmadeusModel = getattr(model_zoo, nn_params.model_name)(
        vocab=vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
        sub_decoder_enricher_use=(nn_params.sub_decoder.feature_enricher_use
                                  if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use')
                                  else False),
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )

    # the empty list stands in for a dataset, which this script does not need
    return AmadeusModel, [], vocab
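
# Config fields read above (for reference):
#   nn_params:    encoding_scheme, num_features, model_name, input_embedder_name,
#                 main_decoder_name, sub_decoder_name, model_dropout,
#                 main_decoder.{dim_model, num_head, num_layer},
#                 optional sub_decoder.{num_layer, feature_enricher_use}
#   train_params: input_length
#   data_params:  first_pred_feature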


def load_resources(exp_dir, device):
    """Load the config, model, and vocab for a wandb experiment directory."""
    exp_dir = Path(exp_dir)
    ckpt_path, config_path, vocab_path = get_best_ckpt_path_and_config(exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)

    ckpt = torch.load(ckpt_path, map_location=device)
    model, _, vocab = prepare_model_and_dataset_from_config(config, vocab_path)
    # strict=False tolerates missing or unexpected keys in the checkpoint
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # torch.compile returns the compiled module; without the assignment it is a no-op
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    return config, model, vocab


def generate_with_text_prompt(config, vocab, model, device, prompt, save_dir,
                              first_pred_feature, sampling_method, threshold,
                              temperature, generation_length=1024):
    encoding_scheme = config.nn_params.encoding_scheme

    # encode the text prompt with a pretrained T5 encoder
    tokenizer = T5Tokenizer.from_pretrained(config.text_encoder_model)
    encoder = T5EncoderModel.from_pretrained(config.text_encoder_model).to(device)
    print(f"Using T5EncoderModel for text prompt:\n{prompt}")
    context = tokenizer(prompt, return_tensors='pt',
                        padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**context).last_hidden_state

    # per-beat time resolution for each supported dataset (default: 4)
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = in_beat_resolution_dict.get(config.dataset, 4)

    # pick the MIDI decoder that matches the encoding scheme
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI',
                         'cp': 'MidiDecoder4CP',
                         'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(
        vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset
    )

    generated_sample = model.generate(
        0, generation_length, condition=None, num_target_measures=None,
        sampling_method=sampling_method, threshold=threshold,
        temperature=temperature, context=context
    )
    # 'nb'-encoded sequences are shifted and padded; undo that before decoding
    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)

    save_dir.mkdir(parents=True, exist_ok=True)

    output_file = save_dir / "generated.mid"
    decoder(generated_sample, output_path=str(output_file))
    print(f"Generated file saved at: {output_file}")


def main():
    args = get_argument_parser().parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config, model, vocab = load_resources(args.wandb_exp_dir, device)

    save_dir = Path(args.output_dir)
    # the text encoder is chosen on the command line, not taken from the training config
    config.text_encoder_model = args.text_encoder_model
    generate_with_text_prompt(
        config,
        vocab,
        model,
        device,
        args.prompt,
        save_dir,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
    )


if __name__ == "__main__":
    main()