336 lines
12 KiB
Python
336 lines
12 KiB
Python
import sys
|
|
import os
|
|
from pathlib import Path
|
|
from multiprocessing import Process,set_start_method
|
|
import torch
|
|
import argparse
|
|
from omegaconf import OmegaConf
|
|
import json
|
|
|
|
from Amadeus.evaluation_utils import (
|
|
wandb_style_config_to_omega_config,
|
|
prepare_model_and_dataset_from_config,
|
|
get_best_ckpt_path_and_config,
|
|
Evaluator
|
|
)
|
|
|
|
def get_argument_parser():
    """Build the command-line parser for the generation script.

    Returns:
        argparse.ArgumentParser: parser whose options control the experiment
        directory, generation mode, sampling settings and multi-GPU fan-out.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="threshold",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="temperature",
    )
    # The meaning of this value depends on -generation_type (see main()).
    parser.add_argument(
        "-num_samples",
        type=int,
        default=30,
        help="number of samples to generate (unconditioned); maximum number of "
             "test tunes (conditioned) or test prompts (text-conditioned) to use",
    )
    parser.add_argument(
        "-num_target_measure",
        type=int,
        default=4,
        help="number of target measures for conditioned generation",
    )
    parser.add_argument(
        "-choose_selected_tunes",
        action='store_true',
        help="generate samples from selected tunes, only for SOD dataset",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=1024,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-num_processes",
        type=int,
        default=4,
        help="number of processes to use",
    )
    parser.add_argument(
        "-gpu_ids",
        type=str,
        default="1,2,3,5",
        help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    # NOTE: no code path in this script reads args.prompt; text-conditioned
    # generation takes its prompts from -prompt_file instead.
    parser.add_argument(
        "-prompt",
        type=str,
        default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
        help="prompt text (currently not read by this script; text-conditioned "
             "generation uses prompts from -prompt_file)",
    )
    parser.add_argument(
        "-prompt_file",
        type=str,
        default="dataset/midicaps/train.json",
        help="file containing prompts for text-conditioned generation",
    )
    return parser
|
|
|
|
def load_resources(wandb_exp_dir, device):
    """Load model and dataset resources for a process.

    Resolves the best checkpoint and its config/metadata/vocab files for the
    experiment, builds the model and test set, restores the weights and puts
    the model in eval mode on ``device``.

    Args:
        wandb_exp_dir: experiment directory name under ``wandb/``.
        device: ``torch.device`` used as ``map_location`` for the checkpoint
            and as the model's target device.

    Returns:
        Tuple ``(config, model, dataset_for_prompt, vocab)``. ``dataset_for_prompt``
        holds one ``(condition, tune_name)`` pair per entry of
        ``test_set.data_list``. ``condition`` is element ``[0]`` of
        ``test_set.get_segments_with_tune_idx(tune_name, 0)`` (presumably the
        first segment of that tune; confirm against the dataset class).
    """
    wandb_dir = Path('wandb')
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
    config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))

    # Load checkpoint to specified device
    ckpt = torch.load(ckpt_path, map_location=device)
    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)

    # strict=False tolerates key mismatches, so report them rather than hide
    # them: a silently half-loaded model yields plausible-looking garbage.
    incompatible = model.load_state_dict(ckpt['model'], strict=False)
    missing_keys = list(getattr(incompatible, 'missing_keys', None) or [])
    unexpected_keys = list(getattr(incompatible, 'unexpected_keys', None) or [])
    if missing_keys:
        print(f"Warning: {len(missing_keys)} model keys missing from checkpoint, e.g. {missing_keys[:5]}")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} unexpected checkpoint keys ignored, e.g. {unexpected_keys[:5]}")

    model.to(device)
    model.eval()
    # torch.compile returns a new optimized module instead of modifying `model`
    # in place, so the former bare `torch.compile(model)` call had no effect and
    # has been removed. To enable compilation, use `model = torch.compile(model)`
    # and verify that the Evaluator's generation path still works.
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare dataset for prompts
    dataset_for_prompt = []
    for entry in test_set.data_list:
        tune_name = entry[1]
        condition = test_set.get_segments_with_tune_idx(tune_name, 0)[0]
        dataset_for_prompt.append((condition, tune_name))

    return config, model, dataset_for_prompt, vocab
|
|
|
|
def conditioned_worker(process_idx, gpu_id, args, data_slice):
    """Continue each tune in ``data_slice`` from a prompt, on a single GPU.

    Meant to run as the target of a child process. Outputs go under
    ``wandb/<exp>/cond_<measures>m_<method>_t<threshold>_temp<temperature>``.
    Every ``(tune_in_idx, tune_name)`` item gets its own sub-folder
    ``process_<process_idx>_batch_<position>``, so workers never share a folder.
    """
    # Pin this process to its GPU before any model or tensor is created.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    run_tag = (
        f"cond_{args.num_target_measure}m_{args.sampling_method}"
        f"_t{args.threshold}_temp{args.temperature}"
    )
    output_root = Path('wandb').joinpath(args.wandb_exp_dir, run_tag)
    output_root.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    for position, (tune_in_idx, tune_name) in enumerate(data_slice):
        item_dir = output_root / f"process_{process_idx}_batch_{position}"
        item_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_prompt(
            item_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
        )
|
|
|
|
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
    """Produce ``num_samples`` unconditioned samples on a single GPU.

    Meant to run as the target of a child process. Every worker writes into
    the same ``wandb/<exp>/uncond_<method>_t<threshold>_temp<temperature>``
    directory. Outputs are tagged with ``uid=str(process_idx)``.
    """
    # Pin this process to its GPU before any model or tensor is created.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    out_dir = Path('wandb').joinpath(
        args.wandb_exp_dir,
        f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}",
    )
    out_dir.mkdir(parents=True, exist_ok=True)

    sampler = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
    sampler.generate_samples_unconditioned(
        out_dir,
        num_samples,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
        uid=str(process_idx),
    )
|
|
def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation.

    Runs text-prompted generation once for every caption in ``data_slice`` on
    a single GPU. All outputs go to
    ``wandb/<exp>/text_condi_<method>_t<threshold>_temp<temperature>``.

    Args:
        process_idx: index of this worker; combined with the prompt's position
            to build a per-prompt ``uid``.
        gpu_id: CUDA device ordinal this worker runs on.
        args: parsed command-line namespace (experiment dir, sampling settings).
        num_samples: not read here; kept so the call signature stays stable.
            The amount of work is ``len(data_slice)``.
        data_slice: caption strings to generate from.
    """
    # Pin this process to its GPU before any model or tensor is created.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources with proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    out_dir = Path('wandb') / args.wandb_exp_dir / \
        f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    out_dir.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    for prompt_idx, prompt in enumerate(data_slice):
        print(f"Process {process_idx} generating samples for tune: {prompt}")
        evaluator.generate_samples_with_text_prompt(
            out_dir,
            prompt,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
            # Every prompt of this worker writes into the same out_dir.
            # A per-prompt uid keeps outputs distinct if file names are
            # derived from uid. NOTE(review): confirm against the
            # Evaluator's naming scheme.
            uid=f"{process_idx}_{prompt_idx}",
        )
|
|
def _parse_gpu_ids(spec):
    """Parse a comma-separated GPU id string such as ``"0,1,3"`` into ints.

    Blank entries (an empty string, a trailing comma) are skipped, so an empty
    spec yields ``[]`` instead of a ``ValueError`` from ``int('')``.
    """
    return [int(token) for token in spec.split(',') if token.strip()]


def _split_into_chunks(items, num_parts):
    """Split ``items`` into ``num_parts`` contiguous chunks of size ceil(len/num_parts).

    Every item lands in exactly one chunk. Trailing chunks are empty when there
    are fewer items than parts.
    """
    chunk_size = (len(items) + num_parts - 1) // num_parts
    return [items[i * chunk_size:(i + 1) * chunk_size] for i in range(num_parts)]


def _load_test_set(wandb_exp_dir):
    """Build the experiment's test set without restoring checkpoint weights.

    Resolves the config/metadata/vocab paths and converts the wandb-style
    config the same way ``load_resources`` does, so the tunes selected here
    match the dataset each worker builds.
    """
    _, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(
        Path('wandb'), wandb_exp_dir)
    config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
    _, test_set, _ = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
    return test_set


def _load_test_prompts(prompt_file, limit):
    """Return up to ``limit`` captions from records whose ``test_set`` is true.

    ``prompt_file`` is read as JSON lines (one object per non-blank line).
    Each record must have a ``test_set`` key; test records must also have
    ``caption``. A non-positive ``limit`` yields no prompts.
    """
    prompts = []
    if limit > 0:
        with open(prompt_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                if record['test_set'] is True:
                    prompts.append(record['caption'])
                    if len(prompts) >= limit:
                        print(f"Reached the limit of {limit} prompts.")
                        break
    print("length of prompt_name_list:", len(prompts))
    return prompts


def _start_worker(processes, target, worker_args, name):
    """Start ``target(*worker_args)`` in a new process and record it in ``processes``."""
    proc = Process(target=target, args=worker_args, name=name)
    proc.start()
    # Only started processes are recorded, so cleanup can always join them.
    processes.append(proc)


def main():
    """CLI entry point: split the requested generation work across GPU worker processes.

    Raises:
        RuntimeError: CUDA is unavailable, or any worker exits with a non-zero code.
        ValueError: the GPU list is empty or names a device that is not visible,
            or fewer than one process is requested.
        FileNotFoundError: the experiment directory is missing (conditioned mode).
    """
    # The CUDA runtime does not support the 'fork' start method; use spawn.
    set_start_method('spawn', force=True)
    args = get_argument_parser().parse_args()
    gpu_ids = _parse_gpu_ids(args.gpu_ids)

    # Validate GPU availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    if len(gpu_ids) == 0:
        raise ValueError("At least one GPU must be specified")
    # Fail fast here instead of letting every child die on an invalid ordinal.
    num_visible = torch.cuda.device_count()
    invalid_ids = [g for g in gpu_ids if not 0 <= g < num_visible]
    if invalid_ids:
        raise ValueError(
            f"Invalid GPU id(s) {invalid_ids}: only {num_visible} CUDA device(s) are visible")

    # Validate process count
    if args.num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if len(gpu_ids) < args.num_processes:
        print(f"Warning: More processes ({args.num_processes}) than GPUs ({len(gpu_ids)}), some GPUs will be shared")

    processes = []
    try:
        if args.generation_type == 'conditioned':
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")

            # Only the test set is needed here to choose which tunes to generate.
            test_set = _load_test_set(args.wandb_exp_dir)

            if args.choose_selected_tunes and test_set.dataset == 'SOD':
                selected_tunes = {'Requiem_orch', 'magnificat_bwv-243_8_orch',
                                  "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"}
            else:
                selected_tunes = set([name for _, name in test_set.data_list][:args.num_samples])

            selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
            for i, data_slice in enumerate(_split_into_chunks(selected_data, args.num_processes)):
                if not data_slice:
                    continue
                gpu_id = gpu_ids[i % len(gpu_ids)]
                _start_worker(processes, conditioned_worker,
                              (i, gpu_id, args, data_slice), f"conditioned-worker-{i}")

        elif args.generation_type == 'unconditioned':
            samples_per_proc, remainder = divmod(args.num_samples, args.num_processes)
            for i in range(args.num_processes):
                samples = samples_per_proc + (1 if i < remainder else 0)
                if samples <= 0:
                    continue
                gpu_id = gpu_ids[i % len(gpu_ids)]
                _start_worker(processes, unconditioned_worker,
                              (i, gpu_id, args, samples), f"unconditioned-worker-{i}")

        elif args.generation_type == 'text-conditioned':
            prompts = _load_test_prompts(args.prompt_file, args.num_samples)
            # Ceil-sized chunks: every loaded prompt is assigned to some worker,
            # even when the count is not divisible by (or smaller than) num_processes.
            for i, data_slice in enumerate(_split_into_chunks(prompts, args.num_processes)):
                if not data_slice:
                    continue
                gpu_id = gpu_ids[i % len(gpu_ids)]
                _start_worker(processes, text_conditioned_worker,
                              (i, gpu_id, args, len(data_slice), data_slice),
                              f"text-conditioned-worker-{i}")

        if not processes:
            print("Warning: no worker processes were started; nothing to generate")

        # Wait for all processes to complete
        for p in processes:
            p.join()

        # Surface child failures: an exception in a worker only sets its exit code.
        failed = [f"{p.name} (exit code {p.exitcode})" for p in processes if p.exitcode != 0]
        if failed:
            raise RuntimeError("Worker process(es) failed: " + ", ".join(failed))

    except BaseException as e:
        # BaseException so that Ctrl-C (KeyboardInterrupt) also cleans up workers.
        print(f"Error in main process: {str(e)}")
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
        raise
|
|
|
|
# Script entry point. The guard is required because main() uses the 'spawn'
# start method: each child re-imports this module and must not re-run main().
if __name__ == "__main__":
    main()