# MIDIFoundationModel/generate-batch.py
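"""Batch sample generation for a trained MIDI foundation model.

Resolves the best checkpoint, config, and vocabulary from a wandb experiment
directory (expected under ./wandb), then spawns one worker process per slot,
assigning GPUs round-robin from -gpu_ids. Supported generation modes:
'unconditioned', 'conditioned', 'attr-conditioned', and 'text-conditioned'.
Generated samples are written back under the experiment directory.

Example invocation (illustrative only; the experiment directory name is a placeholder):
    python generate-batch.py -wandb_exp_dir <run_dir> -generation_type unconditioned \
        -num_samples 30 -num_processes 4 -gpu_ids 0,1,2,3
"""
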
from pathlib import Path
from multiprocessing import Process, set_start_method
import torch
import argparse
from omegaconf import OmegaConf
import json
from Amadeus.evaluation_utils import (
wandb_style_config_to_omega_config,
prepare_model_and_dataset_from_config,
get_best_ckpt_path_and_config,
Evaluator
)
def get_argument_parser():
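    """Build the command-line argument parser for the batch generation script."""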
parser = argparse.ArgumentParser()
parser.add_argument(
"-wandb_exp_dir",
required=True,
type=str,
help="wandb experiment directory",
)
parser.add_argument(
"-generation_type",
type=str,
choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
default='unconditioned',
help="generation type",
)
parser.add_argument(
"-attr_list",
type=str,
        # default="beat,duration,instrument,tempo",
        default="pitch",
        # default='bar,position,velocity,duration,program,tempo,timesig',
        help="comma-separated list of attributes for attribute-controlled generation",
)
parser.add_argument(
"-dataset",
type=str,
help="dataset name, only for conditioned generation",
)
parser.add_argument(
"-sampling_method",
type=str,
choices=('top_p', 'top_k', 'min_p'),
default='top_p',
help="sampling method",
)
parser.add_argument(
"-threshold",
type=float,
default=0.99,
        help="threshold for the selected sampling method (e.g., p for top_p)",
)
parser.add_argument(
"-temperature",
type=float,
default=1.15,
        help="sampling temperature",
)
parser.add_argument(
"-num_samples",
type=int,
default=30,
help="number of samples to generate",
)
parser.add_argument(
"-num_target_measure",
type=int,
default=128,
help="number of target measures for conditioned generation",
)
parser.add_argument(
"-choose_selected_tunes",
action='store_true',
help="generate samples from selected tunes, only for SOD dataset",
)
parser.add_argument(
"-generate_length",
type=int,
default=1024,
help="length of the generated sequence",
)
parser.add_argument(
"-num_processes",
type=int,
default=4,
help="number of processes to use",
)
parser.add_argument(
"-gpu_ids",
type=str,
default="0,1,2,3,5",
help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
)
parser.add_argument(
"-prompt",
type=str,
default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
help="prompt for generation, only used for conditioned generation",
)
parser.add_argument(
"-prompt_file",
type=str,
default="dataset/midicaps/train.json",
help="file containing prompts for text-conditioned generation",
)
return parser
def load_resources(wandb_exp_dir, condition_dataset, device):
    """Load the best checkpoint, config, model, vocab, and per-tune prompt dataset for one worker process."""
wandb_dir = Path('wandb')
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
config = OmegaConf.load(config_path)
config = wandb_style_config_to_omega_config(config)
# Load checkpoint to specified device
print("Loading checkpoint from:", ckpt_path)
ckpt = torch.load(ckpt_path, map_location=device)
print(config)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
model.load_state_dict(ckpt['model'], strict=False)
model.to(device)
model.eval()
    model = torch.compile(model)  # keep the returned module; torch.compile does not compile the model in place
print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
# Prepare dataset for prompts
    condition_list = [x[1] for x in test_set.data_list]
    dataset_for_prompt = []
    for tune_name in condition_list:
        # Pair each tune with its first retrieved segment to serve as the generation prompt
        condition = test_set.get_segments_with_tune_idx(tune_name, 0)[0]
        dataset_for_prompt.append((condition, tune_name))
return config, model, dataset_for_prompt, vocab
def conditioned_worker(process_idx, gpu_id, args):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
    start_idx = process_idx * chunk_size
    end_idx = min(start_idx + chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_prompt(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length
)
def attr_conditioned_worker(process_idx, gpu_id, args):
    """Worker process for attribute-conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
    # Attributes to control, parsed from a comma-separated list (e.g. "pitch" or "bar,position,velocity")
    attr_list = args.attr_list.split(',')
# Load resources with proper device
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
    # Unlike the conditioned worker, each process here iterates over the full selection
    data_slice = selected_data
    print("data_slice length:", len(data_slice))
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_attrCtl(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
attr_list=attr_list
)
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
# Generate assigned number of samples
batch_dir = base_path
evaluator.generate_samples_unconditioned(
batch_dir,
num_samples,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}"
)
def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
# Generate assigned number of samples
batch_dir = base_path
    # data_slice holds caption strings loaded from the prompt file
    for prompt_text in data_slice:
        print(f"Process {process_idx} generating samples for prompt: {prompt_text}")
        evaluator.generate_samples_with_text_prompt(
            batch_dir,
            prompt_text,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}"
)
def main():
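    """Parse CLI arguments, spawn one worker process per slot, and wait for all workers to finish."""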
# use spawn method for multiprocessing
set_start_method('spawn', force=True)
args = get_argument_parser().parse_args()
gpu_ids = list(map(int, args.gpu_ids.split(',')))
# Validate GPU availability
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
if len(gpu_ids) == 0:
raise ValueError("At least one GPU must be specified")
# Validate process count
if args.num_processes < 1:
raise ValueError("Number of processes must be at least 1")
if len(gpu_ids) < args.num_processes:
print(f"Warning: More processes ({args.num_processes}) than GPUs ({len(gpu_ids)}), some GPUs will be shared")
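    # Workers are mapped to GPUs round-robin: worker i runs on gpu_ids[i % len(gpu_ids)]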
# Prepare data slices for processes
processes = []
try:
if args.generation_type == 'conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=conditioned_worker,
args=(i, gpu_id, args)
)
processes.append(p)
p.start()
elif args.generation_type == 'attr-conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=attr_conditioned_worker,
args=(i, gpu_id, args)
)
processes.append(p)
p.start()
elif args.generation_type == 'unconditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
p = Process(
target=unconditioned_worker,
args=(i, gpu_id, args, samples)
)
processes.append(p)
p.start()
elif args.generation_type == 'text-conditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
            # Load prompts from a JSON-lines file; each record is expected to carry 'caption' and 'test_set' fields
prompt_name_list = []
with open(args.prompt_file, 'r') as f:
for line in f:
if not line.strip():
continue
prompt_data = json.loads(line.strip())
prompt_text = prompt_data['caption']
if prompt_data['test_set'] is True:
prompt_name_list.append(prompt_text)
print("length of prompt_name_list:", len(prompt_name_list))
if len(prompt_name_list) >= args.num_samples:
print(f"Reached the limit of {args.num_samples} prompts.")
break
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
                # Split prompt captions across processes; the last process also takes any remainder
                per_proc = len(prompt_name_list) // args.num_processes
                start_idx = i * per_proc
                end_idx = len(prompt_name_list) if i == args.num_processes - 1 else (i + 1) * per_proc
                data_slice = prompt_name_list[start_idx:end_idx]
p = Process(
target=text_conditioned_worker,
args=(i, gpu_id, args, samples, data_slice)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
except Exception as e:
print(f"Error in main process: {str(e)}")
for p in processes:
p.terminate()
raise
if __name__ == "__main__":
main()