import json
import argparse
from pathlib import Path
from multiprocessing import Process, set_start_method

import torch
from omegaconf import OmegaConf

from Amadeus.evaluation_utils import (
    wandb_style_config_to_omega_config,
    prepare_model_and_dataset_from_config,
    get_best_ckpt_path_and_config,
    Evaluator,
)


def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-attr_list",
        type=str,
        default="beat,duration",
        help="comma-separated attribute list for attribute-controlled generation",
    )
    parser.add_argument(
        "-dataset",
        type=str,
        help="dataset name, only for conditioned generation",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k', 'min_p'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold (e.g., the p in top-p)",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-num_samples",
        type=int,
        default=30,
        help="number of samples to generate",
    )
    parser.add_argument(
        "-num_target_measure",
        type=int,
        default=4,
        help="number of target measures for conditioned generation",
    )
    parser.add_argument(
        "-choose_selected_tunes",
        action='store_true',
        help="generate samples from selected tunes, only for the SOD dataset",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=1024,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-num_processes",
        type=int,
        default=1,
        help="number of worker processes to use",
    )
    parser.add_argument(
        "-gpu_ids",
        type=str,
        default="1,2,3,5",
        help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "-prompt",
        type=str,
        default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
        help="prompt for generation, only used for conditioned generation",
    )
    parser.add_argument(
        "-prompt_file",
        type=str,
        default="dataset/midicaps/train.json",
        help="file containing prompts for text-conditioned generation",
    )
    return parser


def load_resources(wandb_exp_dir, condition_dataset, device):
    """Load the model and dataset resources for one worker process."""
    wandb_dir = Path('wandb')
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)

    # Load the checkpoint onto the specified device
    print("Loading checkpoint from:", ckpt_path)
    ckpt = torch.load(ckpt_path, map_location=device)
    print(config)

    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # torch.compile returns the compiled module; keep the reference so it is actually used
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Build (prompt segment, tune name) pairs from the test set
    condition_list = [x[1] for x in test_set.data_list]
    dataset_for_prompt = []
    for i in range(len(condition_list)):
        condition = test_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
        dataset_for_prompt.append((condition, condition_list[i]))
    return config, model, dataset_for_prompt, vocab
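# A minimal standalone sketch of calling load_resources directly (the run
# directory name below is hypothetical; it must match a run under ./wandb):
#
#   device = torch.device('cuda:0')
#   config, model, prompts, vocab = load_resources(
#       'run-20240101_000000-abc123', condition_dataset='SOD', device=device)
#   # `prompts` is a list of (prompt segment, tune name) pairs from the test set.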
def conditioned_worker(process_idx, gpu_id, args):
    """Worker process for conditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on this worker's device
    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

    if args.choose_selected_tunes and test_set.dataset == 'SOD':
        selected_tunes = [
            'Requiem_orch',
            'magnificat_bwv-243_8_orch',
            "Clarinet Concert in A Major: 2nd Movement, Adagio_orch",
        ]
    else:
        selected_tunes = [name for _, name in test_set][:args.num_samples]

    # Split the selected data across processes; each worker takes its own chunk
    selected_data = [d for d in test_set if d[1] in selected_tunes]
    chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
    start_idx = process_idx * chunk_size
    end_idx = min(start_idx + chunk_size, len(selected_data))
    data_slice = selected_data[start_idx:end_idx]

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, data_slice, vocab, device=device)

    # Process the assigned data slice
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
        batch_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_prompt(
            batch_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
        )


def attr_conditioned_worker(process_idx, gpu_id, args):
    """Worker process for attribute-controlled generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')
    attr_list = args.attr_list.split(',')

    # Load resources on this worker's device
    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

    if args.choose_selected_tunes and test_set.dataset == 'SOD':
        selected_tunes = [
            'Requiem_orch',
            'magnificat_bwv-243_8_orch',
            "Clarinet Concert in A Major: 2nd Movement, Adagio_orch",
        ]
    else:
        selected_tunes = [name for _, name in test_set][:args.num_samples]

    # Unlike conditioned_worker, every process handles the full selection
    selected_data = [d for d in test_set if d[1] in selected_tunes]
    data_slice = selected_data

    # Create the output directory, named after the sampling and attribute settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, data_slice, vocab, device=device)

    # Process the assigned data slice
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        evaluator.generate_samples_with_attrCtl(
            base_path,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
            attr_list=attr_list,
        )
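# Example invocation for attribute-controlled generation, as a sketch (the run
# directory is hypothetical, and `generate.py` stands in for this file's name;
# -attr_list names should match features the model was trained with):
#
#   python generate.py -wandb_exp_dir run-20240101_000000-abc123 \
#       -generation_type attr-conditioned -dataset SOD \
#       -attr_list beat,duration -num_target_measure 4 -gpu_ids 0,1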
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
    """Worker process for unconditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on this worker's device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    # Generate this worker's share of the samples
    evaluator.generate_samples_unconditioned(
        base_path,
        num_samples,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
        uid=f"{process_idx}",
    )


def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation."""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on this worker's device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    # Generate samples for each prompt in this worker's slice
    for idx, prompt_text in enumerate(data_slice):
        print(f"Process {process_idx} generating samples for prompt: {prompt_text}")
        evaluator.generate_samples_with_text_prompt(
            base_path,
            prompt_text,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
            uid=f"{process_idx}",
        )
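# Expected -prompt_file format, inferred from the parsing in main() below:
# JSONL with one object per line; only entries whose "test_set" field is true
# are used as prompts, e.g.
#
#   {"caption": "A calm classical piece for string ensemble ...", "test_set": true}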
def main():
    # Use the spawn start method so each worker gets a clean CUDA context
    set_start_method('spawn', force=True)

    args = get_argument_parser().parse_args()
    gpu_ids = list(map(int, args.gpu_ids.split(',')))

    # Validate GPU availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    if len(gpu_ids) == 0:
        raise ValueError("At least one GPU must be specified")

    # Validate process count
    if args.num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if len(gpu_ids) < args.num_processes:
        print(f"Warning: more processes ({args.num_processes}) than GPUs ({len(gpu_ids)}); some GPUs will be shared")

    processes = []
    try:
        if args.generation_type == 'conditioned':
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                p = Process(target=conditioned_worker, args=(i, gpu_id, args))
                processes.append(p)
                p.start()

        elif args.generation_type == 'attr-conditioned':
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                p = Process(target=attr_conditioned_worker, args=(i, gpu_id, args))
                processes.append(p)
                p.start()

        elif args.generation_type == 'unconditioned':
            # Distribute the requested sample count as evenly as possible
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes
            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)
                if samples <= 0:
                    continue
                p = Process(target=unconditioned_worker, args=(i, gpu_id, args, samples))
                processes.append(p)
                p.start()

        elif args.generation_type == 'text-conditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes

            # Load prompts from the JSONL file, keeping only test-set captions
            prompt_name_list = []
            with open(args.prompt_file, 'r') as f:
                for line in f:
                    if not line.strip():
                        continue
                    prompt_data = json.loads(line.strip())
                    if prompt_data['test_set'] is True:
                        prompt_name_list.append(prompt_data['caption'])
                    if len(prompt_name_list) >= args.num_samples:
                        print(f"Reached the limit of {args.num_samples} prompts.")
                        break
            print("length of prompt_name_list:", len(prompt_name_list))

            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)
                if samples <= 0:
                    continue
                # Split the prompts across processes; the last slice absorbs any remainder
                chunk = len(prompt_name_list) // args.num_processes
                start_idx = i * chunk
                end_idx = len(prompt_name_list) if i == args.num_processes - 1 else (i + 1) * chunk
                data_slice = prompt_name_list[start_idx:end_idx]
                p = Process(target=text_conditioned_worker, args=(i, gpu_id, args, samples, data_slice))
                processes.append(p)
                p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()
    except Exception as e:
        print(f"Error in main process: {str(e)}")
        for p in processes:
            p.terminate()
        raise


if __name__ == "__main__":
    main()
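# Example invocations, as sketches (run directory names are hypothetical and
# `generate.py` stands in for this file's actual name):
#
#   python generate.py -wandb_exp_dir run-20240101_000000-abc123 \
#       -generation_type unconditioned -num_samples 30 -num_processes 2 -gpu_ids 0,1
#
#   python generate.py -wandb_exp_dir run-20240101_000000-abc123 \
#       -generation_type text-conditioned -prompt_file dataset/midicaps/train.json \
#       -num_samples 10 -gpu_ids 0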