"""Evaluate generated audio against reference audio with ``audioldm_eval``.

Optionally (``--force_paired``) builds a temporary reference directory in
which each generated ``.wav`` file is paired with a randomly chosen reference
clip copied under the generated file's name.
"""

import argparse
import os
import shutil
import tempfile

import numpy as np
import torch
import torch.multiprocessing as mp  # noqa: F401  (kept from the original file; unused here)
from audioldm_eval import EvaluationHelper, EvaluationHelperParallel

# Sampling rate passed to the audioldm_eval helpers (the original hard-coded value).
SAMPLING_RATE = 16000


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_path", type=str, required=True, help="Path to generated audio files")
    parser.add_argument("--target_path", type=str, required=True, help="Path to reference audio files")
    parser.add_argument("--force_paired", action="store_true", help="Force pairing by randomly selecting reference files")
    parser.add_argument("--gpu_mode", choices=["single", "multi"], default="single", help="Evaluation mode")
    parser.add_argument("--num_gpus", type=int, default=2, help="Number of GPUs for multi-GPU mode")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for --force_paired reference selection (default: unseeded)",
    )
    return parser.parse_args()


def _find_wavs(root_dir):
    """Return the sorted full paths of every ``.wav`` file found recursively under *root_dir*.

    Sorting makes the result independent of ``os.walk`` traversal order, so a
    fixed ``--seed`` reproduces the same pairing.
    """
    return sorted(
        os.path.join(dirpath, name)
        for dirpath, _, filenames in os.walk(root_dir)
        for name in filenames
        if name.endswith(".wav")
    )


def _build_paired_reference_dir(generation_path, target_path, out_dir, seed=None):
    """Fill *out_dir* with random reference clips renamed after the generated clips.

    Args:
        generation_path: Directory searched recursively for generated ``.wav`` files.
        target_path: Directory searched recursively for reference ``.wav`` files.
        out_dir: Existing directory that receives the copied references.
        seed: Optional seed for the random reference selection.

    Raises:
        ValueError: If *target_path* holds fewer ``.wav`` files than *generation_path*.
            References are sampled without replacement, so each generated file
            needs a distinct reference.
    """
    # Only basenames are kept: the copies land flat in out_dir under these names.
    # NOTE(review): generated files in different subdirectories that share a
    # basename collide here and overwrite each other's reference copy.
    gen_files = [os.path.basename(path) for path in _find_wavs(generation_path)]
    print(f"Found {len(gen_files)} generated files in {generation_path}")

    ref_files = _find_wavs(target_path)
    if len(ref_files) < len(gen_files):
        raise ValueError(
            f"--force_paired needs at least {len(gen_files)} reference .wav files "
            f"in {target_path}, but found {len(ref_files)}."
        )

    rng = np.random.default_rng(seed)
    selected_refs = rng.choice(ref_files, size=len(gen_files), replace=False)
    print(f"Selected {len(selected_refs)} reference files for evaluation.")

    for gen_file, ref_path in zip(gen_files, selected_refs):
        shutil.copy(str(ref_path), os.path.join(out_dir, gen_file))


def main():
    """Parse CLI arguments, optionally build a paired reference set, and run the evaluation.

    Any temporary directory created for ``--force_paired`` is removed on exit,
    including when pairing or evaluation raises.
    """
    args = _parse_args()

    target_eval_path = args.target_path
    temp_dir = None
    try:
        if args.force_paired:
            print(f"Using forced pairing with reference files from {args.target_path}")
            # Created inside the try so the finally below removes it even if pairing fails.
            temp_dir = tempfile.mkdtemp()
            _build_paired_reference_dir(
                args.generation_path, args.target_path, temp_dir, seed=args.seed
            )
            target_eval_path = temp_dir

        if args.gpu_mode == "single":
            print("Running single GPU evaluation...")
            evaluator = EvaluationHelper(SAMPLING_RATE, torch.device("cuda:0"))
        else:
            print(f"Running multi-GPU evaluation on {args.num_gpus} GPUs...")
            evaluator = EvaluationHelperParallel(SAMPLING_RATE, args.num_gpus)

        evaluator.main(args.generation_path, target_eval_path)
        print("Evaluation completed.")
    finally:
        # Clean up the temporary paired-reference directory, if one was made.
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)


if __name__ == "__main__":
    main()