first commit
This commit is contained in:
68
SongEval/matrics.py
Normal file
68
SongEval/matrics.py
Normal file
@ -0,0 +1,68 @@
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import torch
|
||||
from audioldm_eval import EvaluationHelper, EvaluationHelperParallel
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--generation_path", type=str, required=True, help="Path to generated audio files")
|
||||
parser.add_argument("--target_path", type=str, required=True, help="Path to reference audio files")
|
||||
parser.add_argument("--force_paired", action="store_true", help="Force pairing by randomly selecting reference files")
|
||||
parser.add_argument("--gpu_mode", choices=["single", "multi"], default="single", help="Evaluation mode")
|
||||
parser.add_argument("--num_gpus", type=int, default=2, help="Number of GPUs for multi-GPU mode")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle forced pairing
|
||||
target_eval_path = args.target_path
|
||||
temp_dir = None
|
||||
if args.force_paired:
|
||||
print(f"Using forced pairing with reference files from {args.target_path}")
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
target_eval_path = temp_dir
|
||||
|
||||
# Collect generated filenames
|
||||
gen_files = []
|
||||
for root, _, files in os.walk(args.generation_path):
|
||||
for file in files:
|
||||
if file.endswith(".wav"):
|
||||
gen_files.append(file)
|
||||
print(f"Found {len(gen_files)} generated files in {args.generation_path}")
|
||||
# Collect all reference files
|
||||
ref_files = []
|
||||
for root, _, files in os.walk(args.target_path):
|
||||
for file in files:
|
||||
if file.endswith(".wav"):
|
||||
ref_files.append(os.path.join(root, file))
|
||||
|
||||
# Select random references matching the count
|
||||
selected_refs = np.random.choice(ref_files, len(gen_files), replace=False)
|
||||
print(f"Selected {len(selected_refs)} reference files for evaluation.")
|
||||
# Copy selected references to temp dir with generated filenames
|
||||
for gen_file, ref_path in zip(gen_files, selected_refs):
|
||||
shutil.copy(ref_path, os.path.join(temp_dir, gen_file))
|
||||
|
||||
|
||||
device = torch.device(f"cuda:{0}") if args.gpu_mode == "single" else None
|
||||
|
||||
try:
|
||||
if args.gpu_mode == "single":
|
||||
print("Running single GPU evaluation...")
|
||||
evaluator = EvaluationHelper(16000, device)
|
||||
metrics = evaluator.main(args.generation_path, target_eval_path)
|
||||
else:
|
||||
print(f"Running multi-GPU evaluation on {args.num_gpus} GPUs...")
|
||||
evaluator = EvaluationHelperParallel(16000, args.num_gpus)
|
||||
metrics = evaluator.main(args.generation_path, target_eval_path)
|
||||
print("Evaluation completed.")
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user