import argparse
import json
from multiprocessing import Process, set_start_method
from pathlib import Path

import torch
from omegaconf import OmegaConf
from transformers import T5Tokenizer, T5EncoderModel

from Amadeus.evaluation_utils import (
    wandb_style_config_to_omega_config,
    prepare_model_and_dataset_from_config,
    get_best_ckpt_path_and_config,
    Evaluator,
)
from Amadeus.symbolic_encoding import decoding_utils
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold (nucleus probability mass for top_p, k for top_k)",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-num_samples",
        type=int,
        default=30,
        help="number of samples to generate",
    )
    parser.add_argument(
        "-num_target_measure",
        type=int,
        default=4,
        help="number of target measures for conditioned generation",
    )
    parser.add_argument(
        "-choose_selected_tunes",
        action='store_true',
        help="generate samples from a fixed set of selected tunes (SOD dataset only)",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=1024,
        help="maximum length of the generated token sequence",
    )
    parser.add_argument(
        "-num_processes",
        type=int,
        default=2,
        help="number of worker processes to spawn",
    )
    parser.add_argument(
        "-gpu_ids",
        type=str,
        default="0,5",
        help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "-prompt",
        type=str,
        default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
        help="text prompt, only used for text-conditioned generation",
    )
    parser.add_argument(
        "-prompt_file",
        type=str,
        default="dataset/midicaps/train.json",
        help="JSONL file containing prompts for text-conditioned generation",
    )
    return parser
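

# Example invocations (illustrative only: the script filename and the wandb run
# directory below are placeholders, not names confirmed by this repository):
#   python generate_parallel.py -wandb_exp_dir run-20240101_000000-abc123 \
#       -generation_type unconditioned -num_samples 30 -gpu_ids 0,1
#   python generate_parallel.py -wandb_exp_dir run-20240101_000000-abc123 \
#       -generation_type text-conditioned -prompt_file dataset/midicaps/train.json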


def load_resources(wandb_exp_dir, device):
    """Load the model, prompt dataset, and vocabulary for a worker process."""
    wandb_dir = Path('wandb')
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)

    # Load the checkpoint directly onto this worker's device
    ckpt = torch.load(ckpt_path, map_location=device)
    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # torch.compile returns the compiled module, so keep the result
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Build (first-segment, tune-name) pairs to serve as generation prompts
    condition_list = [x[1] for x in test_set.data_list]
    dataset_for_prompt = []
    for tune_name in condition_list:
        condition = test_set.get_segments_with_tune_idx(tune_name, 0)[0]
        dataset_for_prompt.append((condition, tune_name))

    return config, model, dataset_for_prompt, vocab


def conditioned_worker(process_idx, gpu_id, args, data_slice):
    """Worker process for conditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources onto this worker's device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create an output directory named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    # Generate from this worker's slice of the prompt tunes
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
        batch_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_prompt(
            batch_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
        )
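

# Conditioned outputs are grouped per worker and batch under
#   wandb/<exp_dir>/cond_<M>m_<method>_t<threshold>_temp<temperature>/process_<p>_batch_<b>/
# so results from different sampling settings land in separate directories.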


def generate_samples_unconditioned(config, vocab, model, device, save_dir, num_samples,
                                   first_pred_feature, sampling_method, threshold, temperature,
                                   generation_length=3072, uid=1):
    # `device` is accepted for interface symmetry; the model is already on it.
    encoding_scheme = config.nn_params.encoding_scheme

    # Per-dataset beat resolution; default to 4 for unknown datasets
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = in_beat_resolution_dict.get(config.dataset, 4)

    # Choose the MIDI decoder that matches the token encoding scheme
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    for i in range(num_samples):
        generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None,
                                          sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
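

# Minimal single-GPU sketch of the unconditioned path (illustrative; 'run-xxxx'
# stands in for a real wandb experiment directory):
#   device = torch.device('cuda:0')
#   config, model, _, vocab = load_resources('run-xxxx', device)
#   out_dir = Path('uncond_samples'); out_dir.mkdir(exist_ok=True)
#   generate_samples_unconditioned(config, vocab, model, device, out_dir, 1,
#                                  config.data_params.first_pred_feature,
#                                  'top_p', 0.99, 1.15, generation_length=1024)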


def generate_samples_with_text_prompt(config, vocab, model, device, save_dir, prompt,
                                      first_pred_feature, sampling_method, threshold, temperature,
                                      generation_length=3072, uid=1):
    encoding_scheme = config.nn_params.encoding_scheme
    # NOTE: the tokenizer/encoder are reloaded on every call; hoist them if
    # generating many prompts per process.
    tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
    encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').to(device)
    print(f"Using T5EncoderModel for text prompt: {prompt}")
    # Encode the prompt into a fixed-length context for conditioning the model
    context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**context).last_hidden_state

    # Per-dataset beat resolution; default to 4 for unknown datasets
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    in_beat_resolution = in_beat_resolution_dict.get(config.dataset, 4)

    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None,
                                      sampling_method=sampling_method, threshold=threshold, temperature=temperature,
                                      context=context)
    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)

    # Use the line count of the jsonl log as the next prompt index; the uid
    # (which embeds the process index) keeps filenames unique across workers.
    jsonl_path = save_dir / "name2prompt.jsonl"
    if jsonl_path.exists():
        with open(jsonl_path, 'r') as f:
            current_idx = sum(1 for _ in f)
    else:
        current_idx = 0

    name = f"prompt_{current_idx}"
    # Append a one-line {name: [prompt]} record mapping the output to its prompt
    with open(jsonl_path, 'a') as f:
        f.write(json.dumps({name: [prompt]}) + '\n')
    decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
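

# Each text-conditioned call appends one record to <save_dir>/name2prompt.jsonl
# and writes the MIDI to "prompt_<idx>_<uid>.mid", so every output file can be
# traced back to its prompt text.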


def unconditioned_worker(process_idx, gpu_id, args, num_samples):
    """Worker process for unconditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources onto this worker's device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create an output directory named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate this worker's share of the samples; the process index in `uid`
    # prevents filename collisions between workers sharing one directory
    generate_samples_unconditioned(
        config,
        vocab,
        model,
        device,
        base_path,
        num_samples,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
        uid=f"{process_idx}"
    )


def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources onto this worker's device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create an output directory named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate one sample per prompt in this worker's slice
    # (`num_samples` is unused here; the slice length determines the count)
    for idx, prompt in enumerate(data_slice):
        print(f"Process {process_idx} generating sample for prompt: {prompt}")
        generate_samples_with_text_prompt(
            config,
            vocab,
            model,
            device,
            base_path,
            prompt=prompt,
            first_pred_feature=config.data_params.first_pred_feature,
            sampling_method=args.sampling_method,
            threshold=args.threshold,
            temperature=args.temperature,
            generation_length=args.generate_length,
            uid=f"{process_idx}_{idx}"
        )


def main():
    # CUDA requires the 'spawn' start method for multiprocessing workers
    set_start_method('spawn', force=True)
    args = get_argument_parser().parse_args()
    gpu_ids = list(map(int, args.gpu_ids.split(',')))

    # Validate GPU availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    if len(gpu_ids) == 0:
        raise ValueError("At least one GPU must be specified")

    # Validate process count
    if args.num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if len(gpu_ids) < args.num_processes:
        print(f"Warning: more processes ({args.num_processes}) than GPUs ({len(gpu_ids)}); some GPUs will be shared")

    # Prepare data slices for the worker processes
    processes = []
    try:
        if args.generation_type == 'conditioned':
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")

            # Load the test set once to enumerate the available tunes; the
            # config is loaded and converted first, matching load_resources
            config = wandb_style_config_to_omega_config(OmegaConf.load(wandb_dir / "files" / "config.yaml"))
            _, test_set, _ = prepare_model_and_dataset_from_config(
                config,
                wandb_dir / "files" / "metadata.json",
                wandb_dir / "files" / "vocab.json"
            )

            if args.choose_selected_tunes and test_set.dataset == 'SOD':
                selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
                                  "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
            else:
                selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]

            # Split the selected tunes across processes in contiguous chunks
            selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
            chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes

            for i in range(args.num_processes):
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, len(selected_data))
                data_slice = selected_data[start_idx:end_idx]

                if not data_slice:
                    continue

                # Assign GPUs round-robin so processes may outnumber GPUs
                gpu_id = gpu_ids[i % len(gpu_ids)]
                p = Process(
                    target=conditioned_worker,
                    args=(i, gpu_id, args, data_slice)
                )
                processes.append(p)
                p.start()

        elif args.generation_type == 'unconditioned':
            # Distribute samples as evenly as possible across processes
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes

            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                p = Process(
                    target=unconditioned_worker,
                    args=(i, gpu_id, args, samples)
                )
                processes.append(p)
                p.start()
        elif args.generation_type == 'text-conditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes
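            # Each line of the prompt file is expected to be a JSON object with
            # at least a "caption" (the prompt text) and a boolean "test_set"
            # flag, e.g.:
            #   {"caption": "A calm piano piece ...", "test_set": true}
            # Only test-set captions are collected, up to num_samples prompts.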
            # Load prompts from file
            prompt_name_list = []
            with open(args.prompt_file, 'r') as f:
                for line in f:
                    if not line.strip():
                        continue
                    prompt_data = json.loads(line)
                    if prompt_data['test_set'] is True:
                        prompt_name_list.append(prompt_data['caption'])
                        if len(prompt_name_list) >= args.num_samples:
                            print(f"Reached the limit of {args.num_samples} prompts.")
                            break
            print("length of prompt_name_list:", len(prompt_name_list))

            # Split the prompts across processes with ceil-division chunking so
            # that every collected prompt is assigned to some process
            chunk_size = (len(prompt_name_list) + args.num_processes - 1) // args.num_processes

            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, len(prompt_name_list))
                data_slice = prompt_name_list[start_idx:end_idx]

                if not data_slice:
                    continue

                p = Process(
                    target=text_conditioned_worker,
                    args=(i, gpu_id, args, samples, data_slice)
                )
                processes.append(p)
                p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

    except Exception as e:
        print(f"Error in main process: {e}")
        for p in processes:
            p.terminate()
        raise


if __name__ == "__main__":
    main()