first commit

This commit is contained in:
2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

210
generate.py Normal file
View File

@ -0,0 +1,210 @@
import torch
from pathlib import Path
import argparse
import json
from collections import defaultdict
from omegaconf import OmegaConf, DictConfig
from transformers import T5Tokenizer, T5EncoderModel
from Amadeus.train_utils import adjust_prediction_order
from Amadeus.evaluation_utils import (
get_dir_from_wandb_by_code,
wandb_style_config_to_omega_config,
)
from Amadeus.symbolic_encoding import decoding_utils, data_utils
from data_representation import vocab_utils
from Amadeus import model_zoo
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
def get_argument_parser():
    """Build the CLI argument parser for text-conditioned generation.

    Returns:
        argparse.ArgumentParser with the following options:
          -wandb_exp_dir (required): wandb experiment directory holding
              checkpoints/config/vocab.
          -prompt (required): text prompt to condition generation on.
          -output_dir: where the generated MIDI is written (default 'outputs').
          -sampling_method: 'top_p' or 'top_k' nucleus/top-k sampling.
          -threshold: sampling threshold (p for top_p, k-ish cutoff for top_k).
          -temperature: softmax temperature for sampling.
          -generate_length: number of tokens to generate.
          -text_encoder_model: HuggingFace id of the pretrained T5 text encoder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-prompt",
        required=True,
        type=str,
        # BUG FIX: help text was garbled ("genvidia-smiration" — stray
        # 'nvidia-smi' pasted into the word "generation").
        help="text prompt for generation",
    )
    parser.add_argument(
        "-output_dir",
        type=str,
        default="outputs",
        help="directory to save results",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="threshold",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="temperature",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=2048,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-text_encoder_model",
        type=str,
        default='google/flan-t5-large',
        help="pretrained text encoder model",
    )
    return parser
def get_best_ckpt_path_and_config(dir):
    """Locate checkpoint, config, and vocab files inside a wandb run directory.

    Args:
        dir: Path to the wandb run directory (or None, which raises).

    Returns:
        (checkpoint_path, config_path, vocab_path) tuple of Paths.

    Raises:
        ValueError: if dir is None (no matching run was found upstream).
        StopIteration: if no vocab* file exists in the checkpoints directory.
        IndexError: if no .pt checkpoint exists at all.
    """
    if dir is None:
        raise ValueError('No such code in wandb_dir')
    ckpt_dir = dir / 'files' / 'checkpoints'
    config_path = dir / 'files' / 'config.yaml'
    vocab_path = next(ckpt_dir.glob('vocab*'))
    # Prefer a checkpoint explicitly marked 'last'. The original globbed the
    # directory twice (len(list(...)) then next(...)) and depended on arbitrary
    # glob order; glob once and sort for a deterministic pick.
    last_ckpts = sorted(ckpt_dir.glob('*last.pt'))
    if last_ckpts:
        last_ckpt_fn = last_ckpts[0]
    else:
        # Fall back to the highest-iteration checkpoint; filenames start with
        # 'iter<NUM>_', so sort numerically on that prefix.
        pt_fns = sorted(ckpt_dir.glob('*.pt'),
                        key=lambda fn: int(fn.stem.split('_')[0].replace('iter', '')))
        last_ckpt_fn = pt_fns[-1]
    return last_ckpt_fn, config_path, vocab_path
def prepare_model_and_dataset_from_config(config: DictConfig, vocab_path: str):
    """Instantiate the vocab and Amadeus model described by a training config.

    Args:
        config: OmegaConf config with nn_params / data_params / train_params.
        vocab_path: path to the serialized vocab file saved during training.

    Returns:
        (model, [], vocab) — the middle element is an empty list placeholder
        kept for interface compatibility with callers expecting a dataset slot.
    """
    nn_params = config.nn_params
    encoding_scheme = nn_params.encoding_scheme
    num_features = nn_params.num_features

    # Encoding scheme -> vocab class name in data_representation.vocab_utils.
    vocab_cls_by_scheme = {
        'remi': 'LangTokenVocab',
        'cp': 'MusicTokenVocabCP',
        'nb': 'MusicTokenVocabNB',
    }
    vocab_cls = getattr(vocab_utils, vocab_cls_by_scheme[encoding_scheme])
    vocab = vocab_cls(
        in_vocab_file_path=Path(vocab_path),
        event_data=None,
        encoding_scheme=encoding_scheme,
        num_features=num_features,
    )

    # Feature prediction order depends on the encoding scheme and the feature
    # the model was trained to predict first.
    prediction_order = adjust_prediction_order(
        encoding_scheme, num_features, config.data_params.first_pred_feature, nn_params
    )

    # Optional sub-decoder settings; configs without a sub_decoder section get
    # depth 0 and no feature enricher.
    has_sub_decoder = hasattr(nn_params, 'sub_decoder')
    sub_decoder_depth = nn_params.sub_decoder.num_layer if has_sub_decoder else 0
    enricher_use = False
    if has_sub_decoder and hasattr(nn_params.sub_decoder, 'feature_enricher_use'):
        enricher_use = nn_params.sub_decoder.feature_enricher_use

    # Build the Transformer named in the config from the model zoo.
    model_cls = getattr(model_zoo, nn_params.model_name)
    AmadeusModel = model_cls(
        vocab=vocab,
        input_length=config.train_params.input_length,
        prediction_order=prediction_order,
        input_embedder_name=nn_params.input_embedder_name,
        main_decoder_name=nn_params.main_decoder_name,
        sub_decoder_name=nn_params.sub_decoder_name,
        sub_decoder_depth=sub_decoder_depth,
        sub_decoder_enricher_use=enricher_use,
        dim=nn_params.main_decoder.dim_model,
        heads=nn_params.main_decoder.num_head,
        depth=nn_params.main_decoder.num_layer,
        dropout=nn_params.model_dropout,
    )
    return AmadeusModel, [], vocab
def load_resources(dir, device):
    """Load config, model, and vocab for a wandb experiment directory.

    Args:
        dir: wandb run directory (str or Path).
        device: torch.device to load the checkpoint onto.

    Returns:
        (config, model, vocab) ready for generation (model is in eval mode).
    """
    dir = Path(dir)
    ckpt_path, config_path, vocab_path = get_best_ckpt_path_and_config(
        dir
    )
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)
    ckpt = torch.load(ckpt_path, map_location=device)
    model, _, vocab = prepare_model_and_dataset_from_config(config, vocab_path)
    # strict=False tolerates key mismatches between checkpoint and model.
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # BUG FIX: torch.compile returns the optimized module rather than mutating
    # in place; the original discarded the result, so compilation had no effect.
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
    return config, model, vocab
def generate_with_text_prompt(config, vocab, model, device, prompt, save_dir,
                              first_pred_feature, sampling_method, threshold,
                              temperature, generation_length=1024):
    """Generate a MIDI file conditioned on a text prompt and save it.

    Encodes the prompt with a pretrained T5 encoder, samples a token sequence
    from the model with that context, then decodes the tokens to MIDI under
    save_dir/generated.mid.
    """
    scheme = config.nn_params.encoding_scheme

    # Encode the text prompt into a conditioning context via T5.
    tokenizer = T5Tokenizer.from_pretrained(config.text_encoder_model)
    encoder = T5EncoderModel.from_pretrained(config.text_encoder_model).to(device)
    print(f"Using T5EncoderModel for text prompt:\n{prompt}")
    tokenized = tokenizer(prompt, return_tensors='pt',
                          padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**tokenized).last_hidden_state

    # Per-dataset in-beat resolution; unknown datasets fall back to 4.
    resolution_by_dataset = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    beat_resolution = resolution_by_dataset.get(config.dataset, 4)

    # Pick the MIDI decoder class matching the encoding scheme.
    decoder_cls_name = {'remi': 'MidiDecoder4REMI',
                        'cp': 'MidiDecoder4CP',
                        'nb': 'MidiDecoder4NB'}[scheme]
    midi_decoder = getattr(decoding_utils, decoder_cls_name)(
        vocab=vocab, in_beat_resolution=beat_resolution, dataset_name=config.dataset
    )

    tokens = model.generate(
        0, generation_length, condition=None, num_target_measures=None,
        sampling_method=sampling_method, threshold=threshold,
        temperature=temperature, context=context
    )

    if scheme == 'nb':
        # NOTE(review): presumably reverses a feature shift applied during NB
        # encoding before decoding — confirm against compile_utils.
        tokens = reverse_shift_and_pad_for_tensor(tokens, first_pred_feature)

    save_dir.mkdir(parents=True, exist_ok=True)
    output_file = save_dir / f"generated.mid"
    midi_decoder(tokens, output_path=str(output_file))
    print(f"Generated file saved at: {output_file}")
def main():
    """CLI entry point: load a trained run and generate MIDI from a prompt."""
    args = get_argument_parser().parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config, model, vocab = load_resources(args.wandb_exp_dir, device)
    # The text encoder is selectable at generation time, independent of training.
    config.text_encoder_model = args.text_encoder_model
    generate_with_text_prompt(
        config,
        vocab,
        model,
        device,
        args.prompt,
        Path(args.output_dir),
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
    )


if __name__ == "__main__":
    main()