import os

import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln
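
# assumed setup: `clap_module` is importable when LAION-CLAP is installed from
# source (https://github.com/LAION-AI/CLAP); with the pip package, the same
# factory module should be available as `laion_clap.clap_module.factory`.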


# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)
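
# Per the LAION-CLAP documentation, the float32 -> int16 -> float32 round-trip
# implemented by these helpers quantizes audio to 16-bit, to be consistent with
# the model's training data, e.g. (hypothetical value):
#   x = np.array([0.5], dtype=np.float32)
#   int16_to_float32(float32_to_int16(x))  # -> ~0.49998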


def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_speech_audioset_epoch_15_esc_89.98.pt'):
"""
|
|
Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
|
|
the LAION-CLAP audio embedding of the generated audio. LION-CLAP: https://github.com/LAION-AI/CLAP
|
|
|
|
This evaluation script assumes that audio_path files are identified with the ids in id2text.
|
|
|
|
clap_score() evaluates all ids in id2text.
|
|
|
|
GPU-based computation.
|
|
|
|
Select one of the following models from https://github.com/LAION-AI/CLAP:
|
|
- music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
|
|
- music_audioset_epoch_15_esc_90.14.pt
|
|
- music_speech_epoch_15_esc_89.25.pt
|
|
- 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs)
|
|
|
|
Params:
|
|
-- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
|
|
and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
|
|
-- audio_path: path where the generated audio files to evaluate are available.
|
|
-- audio_files_extension: files extension (default .wav) in eval_path.
|
|
-- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt').
|
|
Returns:
|
|
-- CLAP-LION score
|
|
"""
    # load model
    if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
        clap_path = 'load/clap_score/music_speech_audioset_epoch_15_esc_89.98.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
        clap_path = 'load/clap_score/music_audioset_epoch_15_esc_90.14.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
        clap_path = 'load/clap_score/music_speech_epoch_15_esc_89.25.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == '630k-audioset-fusion-best.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
        clap_path = 'load/clap_score/630k-audioset-fusion-best.pt'
        model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
    else:
        raise ValueError('clap_model not implemented')
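
    # note: the three music checkpoints above use the HTSAT-base audio encoder
    # without fusion; 630k-audioset-fusion-best enables "fusion", which is meant
    # to handle audio longer than ~10 s by combining a global downsampled view
    # with local crops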

    # download clap_model if not already downloaded
    if not os.path.exists(clap_path):
        print('Downloading ', clap_model, '...')
        os.makedirs(os.path.dirname(clap_path), exist_ok=True)

        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early instead of saving a corrupt checkpoint
        total_size = int(response.headers.get('content-length', 0))

        with open(clap_path, 'wb') as file:
            with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size=8192):
                    file.write(data)
                    progress_bar.update(len(data))
    # fixing LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
    pkg = load_state_dict(clap_path)
    pkg.pop('text_branch.embeddings.position_ids', None)
    model.model.load_state_dict(pkg)
    model.eval()

    if not os.path.isdir(audio_path):
        raise ValueError('audio_path does not exist')

    if id2text:
        print('[EXTRACTING TEXT EMBEDDINGS] ')
        batch_size = 64
        text_emb = {}
        for i in tqdm(range(0, len(id2text), batch_size)):
            batch_ids = list(id2text.keys())[i:i+batch_size]
            batch_texts = [id2text[id] for id in batch_ids]
            with torch.no_grad():
                embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
            for id, emb in zip(batch_ids, embeddings):
                text_emb[id] = emb
    else:
        raise ValueError('Must specify id2text')

    print('[EVALUATING GENERATIONS] ', audio_path)
    score = 0
    count = 0
    for id in tqdm(id2text.keys()):
        file_path = os.path.join(audio_path, str(id)+audio_files_extension)
        with torch.no_grad():
            audio, _ = librosa.load(file_path, sr=48000, mono=True)  # CLAP expects 48 kHz audio
            audio = pyln.normalize.peak(audio, -1.0)  # peak-normalize to -1 dB
            audio = audio.reshape(1, -1)  # unsqueeze to shape (1, T)
            audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
            audio_embeddings = model.get_audio_embedding_from_data(x=audio, use_tensor=True)
            cosine_sim = torch.nn.functional.cosine_similarity(
                audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
        score += cosine_sim.item()
        count += 1

    return score / count if count > 0 else 0
if __name__ == "__main__":
|
|
|
|
import pandas as pd
|
|
import json
|
|
import argparse
|
|
parser = argparse.ArgumentParser(description='Compute CLAP score for generated audio files.')
|
|
parser.add_argument('--clap_model', type=str, default='630k-audioset-fusion-best.pt',
|
|
help='CLAP model to use for evaluation. Options: music_speech_audioset_epoch_15_esc_89.98.pt, music_audioset_epoch_15_esc_90.14.pt, music_speech_epoch_15_esc_89.25.pt, 630k-audioset-fusion-best.pt (default: 630k-audioset-fusion-best.pt)')
|
|
parser.add_argument('--root_path', type=str, default='../wandb/run-20250627_172105-xpe7nh5n-worseInstr/generated_samples_text_conditioned_top_p_threshold_0.99_temperature_1.15_8',
|
|
help='Path to the directory containing generated audio files and id2text mapping.')
|
|
args = parser.parse_args()
|
|
    clap_model = args.clap_model
    root_path = args.root_path
    json_file_path = os.path.join(root_path, 'name2prompt.jsonl')
    generated_path = os.path.join(root_path, 'prompt_music')
    if not os.path.exists(generated_path):
        generated_path = root_path  # if there is no 'prompt_music' subfolder, use root_path directly
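
    # name2prompt.jsonl is expected to hold one JSON object per line, mapping an
    # id to a list of prompts, e.g. (hypothetical ids and prompts):
    #   {"sample_0": ["a calm piano melody with soft strings"]}
    #   {"sample_1": ["energetic drum and bass with a deep bassline"]}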
    with open(json_file_path, 'r') as f:
        id2text_dict = {}
        for line in f:
            item = json.loads(line)
            for k, v in item.items():
                id2text_dict[k] = v[0]  # each value is a list of prompts; keep the first one
    print('length of id2text:', len(id2text_dict))

    # match each id in name2prompt.jsonl to a generated wav file: files may be
    # named either <id>.wav or <id>_<i>.wav
    id2text = {}
    for k, v in id2text_dict.items():
        if isinstance(v, list):  # defensive: in case a value is still a list of prompts
            v = v[0]
        # check if k exists as a wav file
        if os.path.exists(os.path.join(generated_path, str(k)+'.wav')):
            id2text[k] = v
        else:
            # look for variations k_0, k_1, ... and check if they exist
            for i in range(0, 10):  # assuming no more than 10 variations
                if os.path.exists(os.path.join(generated_path, str(k)+'_'+str(i)+'.wav')):
                    new_key = str(k) + '_' + str(i)
                    id2text[new_key] = v
    print('length of id2text after checking wav files:', len(id2text))

    # check that each wav exists; skip ids without a corresponding file
    new_id2text = {}
    for id in id2text.keys():
        file_path = os.path.join(generated_path, str(id)+'.wav')
        if os.path.exists(file_path):
            new_id2text[id] = id2text[id]
        else:
            print(f"Warning: {file_path} does not exist, skipping this id.")
    print('length of new_id2text:', len(new_id2text))
"""
|
|
IMPORTANT: the audios in generated_path should have the same ids as in id2text.
|
|
For musiccaps, you can load id2text as above and each generated_path audio file
|
|
corresponds to a prompt (text description) in musiccaps. Files are named with ids, as follows:
|
|
- your_model_outputs_folder/_-kssA-FOzU.wav
|
|
- your_model_outputs_folder/_0-2meOf9qY.wav
|
|
- your_model_outputs_folder/_1woPC5HWSg.wav
|
|
...
|
|
- your_model_outputs_folder/ZzyWbehtt0M.wav
|
|
"""
|
|
|
|
    clp = clap_score(new_id2text, generated_path, audio_files_extension='.wav', clap_model=clap_model)
    print('CLAP score (cosine similarity):', clp)