# MIDIFoundationModel/SongEval/eval.py
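"""Score audio with a pretrained SongEval rating model.

Inputs (.wav/.mp3 audio, or pre-computed .npy SSL features) are scored on
five axes: Coherence, Musicality, Memorability, Clarity, and Naturalness.
Per-file scores and their average are written to <output_dir>/result.json.
"""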

import glob
import os
import json
import librosa
import numpy as np
import torch
import argparse
from muq import MuQ
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm
class Synthesizer:
    def __init__(self,
                 checkpoint_path,
                 input_path,
                 output_dir,
                 use_cpu: bool = False):
        self.checkpoint_path = checkpoint_path
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.device = torch.device('cuda') if (torch.cuda.is_available() and not use_cpu) else torch.device('cpu')
    @torch.no_grad()
    def setup(self):
        # The training config is expected one level above the checkpoint
        # directory (e.g. ckpt/model.safetensors -> ./config.yaml).
        train_config = OmegaConf.load(os.path.join(os.path.dirname(self.checkpoint_path), '../config.yaml'))
        model = instantiate(train_config.generator).to(self.device).eval()
        state_dict = load_file(self.checkpoint_path, device="cpu")
        model.load_state_dict(state_dict, strict=False)
        self.model = model
        # MuQ supplies the self-supervised embeddings the rating model consumes.
        self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
        self.muq = self.muq.to(self.device).eval()
        self.result_dict = {}
    @torch.no_grad()
    def synthesis(self):
        # Collect inputs: a single audio file, a text filelist, or a directory.
        if os.path.isfile(self.input_path):
            if self.input_path.endswith(('.wav', '.mp3')):
                lines = [self.input_path]
            else:
                # Treat any other file as a newline-separated list of paths.
                with open(self.input_path, "r") as f:
                    lines = [line.strip() for line in f if line.strip()]
            input_files = [{"input_path": line} for line in lines]
            print(f"input filelist: {self.input_path}")
        elif os.path.isdir(self.input_path):
            input_files = [{"input_path": file}
                           for file in glob.glob(os.path.join(self.input_path, '*'))
                           if file.lower().endswith(('.wav', '.mp3'))]
        else:
            raise ValueError(f"input_path {self.input_path} is not a file or directory")
        for item in tqdm(input_files):
            try:
                self.handle(**item)
            except Exception as e:
                print(f"failed on {item['input_path']}: {e}")
                continue
        # Append the average of each metric across all scored files.
        if self.result_dict:
            avg_values = {}
            for key in next(iter(self.result_dict.values())).keys():
                avg_values[key] = round(float(np.mean([self.result_dict[fid][key] for fid in self.result_dict])), 4)
            self.result_dict['average'] = avg_values
        # Save per-file and averaged scores.
        with open(os.path.join(self.output_dir, "result.json"), "w") as f:
            json.dump(self.result_dict, f, indent=4, ensure_ascii=False)
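    # Shape of result.json written above (scores below are illustrative):
    #   {
    #     "song_001": {"Coherence": 3.91, "Musicality": 3.80, ...},
    #     "average":  {"Coherence": 3.85, ...}
    #   }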
    @torch.no_grad()
    def handle(self, input_path):
        fid = os.path.splitext(os.path.basename(input_path))[0]
        if input_path.endswith('.npy'):
            # Pre-computed SSL features: validate shape and NaNs before use.
            feats = np.load(input_path)
            if len(feats.shape) == 3 and feats.shape[0] != 1:
                print('ssl_shape error', input_path)
                return
            if np.isnan(feats).any():
                print('ssl nan', input_path)
                return
            feats = torch.from_numpy(feats).to(self.device)
            if len(feats.shape) == 2:
                feats = feats.unsqueeze(0)
        elif input_path.endswith(('.wav', '.mp3')):
            # Raw audio: resample to 24 kHz and take MuQ hidden layer 6
            # as the input features for the rating model.
            wav, sr = librosa.load(input_path, sr=24000)
            audio = torch.tensor(wav).unsqueeze(0).to(self.device)
            output = self.muq(audio, output_hidden_states=True)
            feats = output["hidden_states"][6]
        else:
            print('unsupported input', input_path)
            return
        values = {}
        scores_g = self.model(feats).squeeze(0)
        values['Coherence'] = round(scores_g[0].item(), 4)
        values['Musicality'] = round(scores_g[1].item(), 4)
        values['Memorability'] = round(scores_g[2].item(), 4)
        values['Clarity'] = round(scores_g[3].item(), 4)
        values['Naturalness'] = round(scores_g[4].item(), 4)
        self.result_dict[fid] = values
        # Free the large tensors between files to ease GPU memory pressure.
        del feats, scores_g
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input_path",
        type=str,
        required=True,
        help="Input audio: path to a single file, a text file listing audio paths, or a directory of audio files."
    )
    parser.add_argument(
        "-o", "--output_dir",
        type=str,
        required=True,
        help="Output directory for generated results (will be created if it doesn't exist)."
    )
    parser.add_argument(
        "--use_cpu",
        action="store_true",
        help="Force CPU mode even if a GPU is available."
    )
    args = parser.parse_args()
    ckpt_path = "ckpt/model.safetensors"
    synthesizer = Synthesizer(checkpoint_path=ckpt_path,
                              input_path=args.input_path,
                              output_dir=args.output_dir,
                              use_cpu=args.use_cpu)
    synthesizer.setup()
    synthesizer.synthesis()
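# Example invocations (paths are illustrative):
#   python eval.py -i song.wav -o results/
#   python eval.py -i filelist.txt -o results/
#   python eval.py -i songs/ -o results/ --use_cpu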