# MIDIFoundationModel/generate.py
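"""Generate a MIDI file from a text prompt with a trained Amadeus checkpoint.

Loads the latest checkpoint and config from a wandb experiment directory,
encodes the prompt with a pretrained T5 text encoder, samples a token
sequence from the model, and decodes it into a MIDI file.
"""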
import torch
from pathlib import Path
import argparse
import json
from collections import defaultdict
from omegaconf import OmegaConf, DictConfig
from transformers import T5Tokenizer, T5EncoderModel
from Amadeus.train_utils import adjust_prediction_order
from Amadeus.evaluation_utils import (
    get_dir_from_wandb_by_code,
    wandb_style_config_to_omega_config,
)
from Amadeus.symbolic_encoding import decoding_utils, data_utils
from data_representation import vocab_utils
from Amadeus import model_zoo
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-prompt",
        required=True,
        type=str,
        help="text prompt for generation",
    )
    parser.add_argument(
        "-output_dir",
        type=str,
        default="outputs",
        help="directory to save results",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold (nucleus p for top_p)",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=2048,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-text_encoder_model",
        type=str,
        default='google/flan-t5-large',
        help="pretrained text encoder model",
    )
    return parser


def get_best_ckpt_path_and_config(exp_dir):
    """Return the latest checkpoint, config, and vocab paths under a wandb run dir."""
    if exp_dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = exp_dir / 'files' / 'checkpoints'
    config_path = exp_dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))
    # Prefer a checkpoint ending in 'last.pt'; otherwise take the one with the
    # highest iteration number encoded in its filename.
    if len(list(ckpt_dir.glob('*last.pt'))) > 0:
        last_ckpt_fn = next(ckpt_dir.glob('*last.pt'))
    else:
        pt_fns = sorted(
            ckpt_dir.glob('*.pt'),
            key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')),
        )
        last_ckpt_fn = pt_fns[-1]
    return last_ckpt_fn, config_path, vocab_path


def prepare_model_and_dataset_from_config(config: DictConfig, vocab_path: str):
    nn_params = config.nn_params
    vocab_path = Path(vocab_path)
    encoding_scheme = config.nn_params.encoding_scheme
    num_features = config.nn_params.num_features
    # Map each encoding scheme to its vocabulary class.
    vocab_name = {'remi': 'LangTokenVocab', 'cp': 'MusicTokenVocabCP', 'nb': 'MusicTokenVocabNB'}
    selected_vocab_name = vocab_name[encoding_scheme]
    vocab = getattr(vocab_utils, selected_vocab_name)(
        in_vocab_file_path=vocab_path,
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features,
    )
    # Get the proper prediction order for the encoding scheme and the target feature in the config.
    prediction_order = adjust_prediction_order(
        encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params
    )
    # Build the Transformer model from the configuration parameters.
    AmadeusModel = getattr(model_zoo, nn_params.model_name)(
        vocab=vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=nn_params.sub_decoder.num_layer if hasattr(nn_params, 'sub_decoder') else 0,
        sub_decoder_enricher_use=(
            nn_params.sub_decoder.feature_enricher_use
            if hasattr(nn_params, 'sub_decoder') and hasattr(nn_params.sub_decoder, 'feature_enricher_use')
            else False
        ),
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )
    return AmadeusModel, [], vocab


def load_resources(exp_dir, device):
    """Load the config, model weights, and vocabulary from a wandb run directory."""
    exp_dir = Path(exp_dir)
    ckpt_path, config_path, vocab_path = get_best_ckpt_path_and_config(exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)
    ckpt = torch.load(ckpt_path, map_location=device)
    model, _, vocab = prepare_model_and_dataset_from_config(config, vocab_path)
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # torch.compile returns the compiled module; keep the result so it is actually used.
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
    return config, model, vocab


def generate_with_text_prompt(config, vocab, model, device, prompt, save_dir,
                              first_pred_feature, sampling_method, threshold,
                              temperature, generation_length=1024):
    encoding_scheme = config.nn_params.encoding_scheme
    # Encode the text prompt with a pretrained T5 encoder.
    tokenizer = T5Tokenizer.from_pretrained(config.text_encoder_model)
    encoder = T5EncoderModel.from_pretrained(config.text_encoder_model).to(device)
    print(f"Using T5EncoderModel for text prompt:\n{prompt}")
    context = tokenizer(prompt, return_tensors='pt',
                        padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**context).last_hidden_state
    # Per-beat quantization resolution differs by training dataset; default to 4 if unknown.
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = in_beat_resolution_dict.get(config.dataset, 4)
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI',
                         'cp': 'MidiDecoder4CP',
                         'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(
        vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset
    )
    generated_sample = model.generate(
        0, generation_length, condition=None, num_target_measures=None,
        sampling_method=sampling_method, threshold=threshold,
        temperature=temperature, context=context
    )
    # The 'nb' encoding is generated in shifted form; reverse the shift before decoding.
    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
    save_dir.mkdir(parents=True, exist_ok=True)
    output_file = save_dir / "generated.mid"
    decoder(generated_sample, output_path=str(output_file))
    print(f"Generated file saved at: {output_file}")


def main():
    args = get_argument_parser().parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config, model, vocab = load_resources(args.wandb_exp_dir, device)
    save_dir = Path(args.output_dir)
    # The text encoder is chosen at generation time, overriding whatever is in the config.
    config.text_encoder_model = args.text_encoder_model
    generate_with_text_prompt(
        config,
        vocab,
        model,
        device,
        args.prompt,
        save_dir,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
    )


if __name__ == "__main__":
    main()
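
# Example invocation (illustrative; the run directory and prompt below are
# placeholders, not taken from the source):
#   python generate.py \
#       -wandb_exp_dir wandb/run-20250908_144928-abc123 \
#       -prompt "an upbeat solo piano piece" \
#       -sampling_method top_p -threshold 0.99 -temperature 1.15 \
#       -generate_length 2048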