import argparse
import os
import shutil
import tempfile

import numpy as np
import torch

from audioldm_eval import EvaluationHelper, EvaluationHelperParallel
import torch.multiprocessing as mp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_path", type=str, required=True, help="Path to generated audio files")
    parser.add_argument("--target_path", type=str, required=True, help="Path to reference audio files")
    parser.add_argument("--force_paired", action="store_true", help="Force pairing by randomly selecting reference files")
    parser.add_argument("--gpu_mode", choices=["single", "multi"], default="single", help="Evaluation mode")
    parser.add_argument("--num_gpus", type=int, default=2, help="Number of GPUs for multi-GPU mode")
    args = parser.parse_args()

    # Handle forced pairing: build a temporary reference directory whose files
    # mirror the generated filenames so the evaluator treats them as pairs.
    target_eval_path = args.target_path
    temp_dir = None
    if args.force_paired:
        print(f"Using forced pairing with reference files from {args.target_path}")
        temp_dir = tempfile.mkdtemp()
        target_eval_path = temp_dir

        # Collect generated filenames
        gen_files = []
        for root, _, files in os.walk(args.generation_path):
            for file in files:
                if file.endswith(".wav"):
                    gen_files.append(file)
        print(f"Found {len(gen_files)} generated files in {args.generation_path}")

        # Collect all reference files
        ref_files = []
        for root, _, files in os.walk(args.target_path):
            for file in files:
                if file.endswith(".wav"):
                    ref_files.append(os.path.join(root, file))

        # Select random references matching the count of generated files
        selected_refs = np.random.choice(ref_files, len(gen_files), replace=False)
        print(f"Selected {len(selected_refs)} reference files for evaluation.")

        # Copy selected references to the temp dir under the generated filenames
        for gen_file, ref_path in zip(gen_files, selected_refs):
            shutil.copy(ref_path, os.path.join(temp_dir, gen_file))

    device = torch.device("cuda:0") if args.gpu_mode == "single" else None

    try:
        if args.gpu_mode == "single":
            print("Running single GPU evaluation...")
            evaluator = EvaluationHelper(16000, device)
            metrics = evaluator.main(args.generation_path, target_eval_path)
        else:
            print(f"Running multi-GPU evaluation on {args.num_gpus} GPUs...")
            evaluator = EvaluationHelperParallel(16000, args.num_gpus)
            metrics = evaluator.main(args.generation_path, target_eval_path)
        print("Evaluation completed.")

    finally:
        # Clean up the temporary directory created for forced pairing
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)


if __name__ == "__main__":
    main()
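
# A minimal usage sketch. The script filename ("evaluate_audio.py") and the
# directory paths below are assumptions for illustration, not taken from the
# repository:
#
#   python evaluate_audio.py \
#       --generation_path ./outputs/generated \
#       --target_path ./data/references \
#       --force_paired \
#       --gpu_mode multi --num_gpus 2
#
# With --force_paired, each generated .wav is matched to a randomly chosen
# reference copied under the same filename, so the evaluator runs in paired mode.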