1021 add flexable attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@ -1,3 +1,4 @@
from ast import arg
import sys
import os
from pathlib import Path
@ -25,14 +26,25 @@ def get_argument_parser():
parser.add_argument(
"-generation_type",
type=str,
choices=('conditioned', 'unconditioned', 'text-conditioned'),
choices=('conditioned', 'unconditioned', 'text-conditioned', 'attr-conditioned'),
default='unconditioned',
help="generation type",
)
parser.add_argument(
"-attr_list",
type=str,
default="beat,duration",
help="attribute list for attribute-controlled generation",
)
parser.add_argument(
"-dataset",
type=str,
help="dataset name, only for conditioned generation",
)
parser.add_argument(
"-sampling_method",
type=str,
choices=('top_p', 'top_k'),
choices=('top_p', 'top_k', 'min_p'),
default='top_p',
help="sampling method",
)
@ -74,7 +86,7 @@ def get_argument_parser():
parser.add_argument(
"-num_processes",
type=int,
default=4,
default=1,
help="number of processes to use",
)
parser.add_argument(
@ -97,7 +109,7 @@ def get_argument_parser():
)
return parser
def load_resources(wandb_exp_dir, device):
def load_resources(wandb_exp_dir, condition_dataset, device):
"""Load model and dataset resources for a process"""
wandb_dir = Path('wandb')
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
@ -107,7 +119,8 @@ def load_resources(wandb_exp_dir, device):
# Load checkpoint to specified device
print("Loading checkpoint from:", ckpt_path)
ckpt = torch.load(ckpt_path, map_location=device)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
print(config)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path, condition_dataset)
model.load_state_dict(ckpt['model'], strict=False)
model.to(device)
model.eval()
@ -123,20 +136,33 @@ def load_resources(wandb_exp_dir, device):
return config, model, dataset_for_prompt, vocab
def conditioned_worker(process_idx, gpu_id, args, data_slice):
def conditioned_worker(process_idx, gpu_id, args):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
start_idx = 1
end_idx = min(chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
@ -154,13 +180,62 @@ def conditioned_worker(process_idx, gpu_id, args, data_slice):
generation_length=args.generate_length
)
def attr_conditioned_worker(process_idx, gpu_id, args):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# attr_list = "position,duration"
attr_list = args.attr_list.split(',')
# Load resources with proper device
config, model, test_set, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# print(test_set)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set if d[1] in selected_tunes]
# chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
# start_idx = 1
# end_idx = min(chunk_size, len(selected_data))
# data_slice = selected_data[start_idx:end_idx]
data_slice = selected_data
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"attrcond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}_attrs{'-'.join(attr_list)}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, data_slice, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_attrCtl(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
attr_list=attr_list
)
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
@ -187,7 +262,7 @@ def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, args.dataset, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
@ -237,40 +312,33 @@ def main():
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
# Load test set to get selected tunes (dummy load to get dataset info)
dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_, test_set, _ = prepare_model_and_dataset_from_config(
wandb_dir / "files" / "config.yaml",
wandb_dir / "files" / "metadata.json",
wandb_dir / "files" / "vocab.json"
)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
for i in range(args.num_processes):
start_idx = i * chunk_size
end_idx = min((i+1)*chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
if not data_slice:
continue
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=conditioned_worker,
args=(i, gpu_id, args, data_slice)
args=(i, gpu_id, args)
)
processes.append(p)
p.start()
elif args.generation_type == 'attr-conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=attr_conditioned_worker,
args=(i, gpu_id, args)
)
processes.append(p)
p.start()
elif args.generation_type == 'unconditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes