1021 add flexible attr control
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from ast import arg
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
@ -25,14 +26,25 @@ def get_argument_parser():
|
||||
parser.add_argument(
|
||||
"-generation_type",
|
||||
type=str,
|
||||
choices=('conditioned', 'unconditioned', 'text-conditioned'),
|
||||
choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
|
||||
default='unconditioned',
|
||||
help="generation type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-attr_list",
|
||||
type=str,
|
||||
default="beat,duration",
|
||||
help="attribute list for attribute-controlled generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-dataset",
|
||||
type=str,
|
||||
help="dataset name, only for conditioned generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-sampling_method",
|
||||
type=str,
|
||||
choices=('top_p', 'top_k'),
|
||||
choices=('top_p', 'top_k', 'min_p'),
|
||||
default='top_p',
|
||||
help="sampling method",
|
||||
)
|
||||
@ -74,7 +86,7 @@ def get_argument_parser():
|
||||
parser.add_argument(
|
||||
"-num_processes",
|
||||
type=int,
|
||||
default=4,
|
||||
default=1,
|
||||
help="number of processes to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -97,7 +109,7 @@ def get_argument_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def load_resources(wandb_exp_dir, device):
|
||||
def load_resources(wandb_exp_dir, condition_dataset, device):
|
||||
"""Load model and dataset resources for a process"""
|
||||
wandb_dir = Path('wandb')
|
||||
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
|
||||
@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, device):
|
||||
# Load checkpoint to specified device
|
||||
print("Loading checkpoint from:", ckpt_path)
|
||||
ckpt = torch.load(ckpt_path, map_location=device)
|
||||
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
|
||||
print(config)
|
||||
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
|
||||
model.load_state_dict(ckpt['model'], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, device):
|
||||
|
||||
return config, model, dataset_for_prompt, vocab
|
||||
|
||||
def conditioned_worker(process_idx, gpu_id, args, data_slice):
|
||||
def conditioned_worker(process_idx, gpu_id, args):
|
||||
"""Worker process for conditioned generation"""
|
||||
torch.cuda.set_device(gpu_id)
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
# print(test_set)
|
||||
if args.choose_selected_tunes and test_set.dataset == 'SOD':
|
||||
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
|
||||
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
|
||||
else:
|
||||
selected_tunes = [name for _, name in test_set][:args.num_samples]
|
||||
|
||||
# Split selected data across processes
|
||||
selected_data = [d for d in test_set if d[1] in selected_tunes]
|
||||
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
|
||||
start_idx = 1
|
||||
end_idx = min(chunk_size, len(selected_data))
|
||||
data_slice = selected_data[start_idx:end_idx]
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
|
||||
base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
|
||||
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
|
||||
|
||||
# Process assigned data slice
|
||||
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
|
||||
@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
|
||||
generation_length=args.generate_length
|
||||
)
|
||||
|
||||
def attr_conditioned_worker(process_idx, gpu_id, args):
    """Worker process for attribute-conditioned generation.

    Loads the model and test set on the given GPU, picks the tunes to use as
    prompts, and calls ``Evaluator.generate_samples_with_attrCtl`` once for
    every tune assigned to this worker.

    Args:
        process_idx: Index of this worker in ``range(args.num_processes)``.
            It selects which contiguous chunk of the selected tunes this
            worker handles, so parallel workers do not repeat each other's
            work or write the same outputs concurrently.
        gpu_id: CUDA device index this worker runs on.
        args: Parsed command-line namespace. Reads ``attr_list``,
            ``wandb_exp_dir``, ``dataset``, ``choose_selected_tunes``,
            ``num_samples``, ``num_processes``, ``num_target_measure``,
            ``sampling_method``, ``threshold``, ``temperature`` and
            ``generate_length``.

    Returns:
        None. Returns early, before building the evaluator, when no tunes
        fall into this worker's chunk.
    """
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # "-attr_list" arrives as one comma-separated string, e.g. "beat,duration".
    attr_list = args.attr_list.split(',')

    # Load resources with proper device
    config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)

    if args.choose_selected_tunes and test_set.dataset == 'SOD':
        selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
                          "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
    else:
        selected_tunes = [name for _, name in test_set][:args.num_samples]

    # A set gives O(1) membership tests while filtering the test set.
    selected_names = set(selected_tunes)
    selected_data = [d for d in test_set if d[1] in selected_names]

    # Split selected data across processes: ceil-divide into contiguous
    # chunks and keep only the chunk that belongs to this worker.
    chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
    start_idx = process_idx * chunk_size
    data_slice = selected_data[start_idx:start_idx + chunk_size]
    if not data_slice:
        # More workers than tunes: nothing assigned to this process.
        return

    # All workers share one output directory; exist_ok makes the concurrent
    # creation safe.
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, data_slice, vocab, device=device)

    # Process assigned data slice
    for tune_in_idx, tune_name in data_slice:
        evaluator.generate_samples_with_attrCtl(
            base_path,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length,
            attr_list=attr_list
        )
|
||||
|
||||
|
||||
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
|
||||
"""Worker process for unconditioned generation"""
|
||||
torch.cuda.set_device(gpu_id)
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
|
||||
device = torch.device(f'cuda:{gpu_id}')
|
||||
|
||||
# Load resources with proper device
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
|
||||
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
|
||||
|
||||
# Create output directory with process index
|
||||
base_path = Path('wandb') / args.wandb_exp_dir / \
|
||||
@ -237,40 +312,33 @@ def main():
|
||||
if not wandb_dir.exists():
|
||||
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
|
||||
|
||||
# Load test set to get selected tunes (dummy load to get dataset info)
|
||||
dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
_, test_set, _ = prepare_model_and_dataset_from_config(
|
||||
wandb_dir / "files" / "config.yaml",
|
||||
wandb_dir / "files" / "metadata.json",
|
||||
wandb_dir / "files" / "vocab.json"
|
||||
)
|
||||
|
||||
if args.choose_selected_tunes and test_set.dataset == 'SOD':
|
||||
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
|
||||
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
|
||||
else:
|
||||
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
|
||||
|
||||
# Split selected data across processes
|
||||
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
|
||||
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
|
||||
|
||||
for i in range(args.num_processes):
|
||||
start_idx = i * chunk_size
|
||||
end_idx = min((i+1)*chunk_size, len(selected_data))
|
||||
data_slice = selected_data[start_idx:end_idx]
|
||||
|
||||
if not data_slice:
|
||||
continue
|
||||
|
||||
gpu_id = gpu_ids[i % len(gpu_ids)]
|
||||
p = Process(
|
||||
target=conditioned_worker,
|
||||
args=(i, gpu_id, args, data_slice)
|
||||
args=(i, gpu_id, args)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
elif args.generation_type == 'attr-conditioned':
|
||||
# Prepare selected tunes
|
||||
wandb_dir = Path('wandb') / args.wandb_exp_dir
|
||||
if not wandb_dir.exists():
|
||||
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
|
||||
|
||||
|
||||
for i in range(args.num_processes):
|
||||
|
||||
gpu_id = gpu_ids[i % len(gpu_ids)]
|
||||
p = Process(
|
||||
target=attr_conditioned_worker,
|
||||
args=(i, gpu_id, args)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
elif args.generation_type == 'unconditioned':
|
||||
samples_per_proc = args.num_samples // args.num_processes
|
||||
remainder = args.num_samples % args.num_processes
|
||||
|
||||
Reference in New Issue
Block a user