first commit
BIN SongEval/.DS_Store vendored Normal file
Binary file not shown.
201 SongEval/LICENSE Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
88 SongEval/README.md Normal file
@@ -0,0 +1,88 @@
# 🎵 SongEval: A Benchmark Dataset for Song Aesthetics Evaluation

[🤗 Dataset](https://huggingface.co/datasets/ASLP-lab/SongEval)
[📄 Paper](https://arxiv.org/pdf/2505.10793)
[License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)

This repository provides a **trained aesthetic evaluation toolkit** based on [SongEval](https://huggingface.co/datasets/ASLP-lab/SongEval), the first large-scale, open-source dataset for human-perceived song aesthetics. The toolkit enables **automatic scoring of generated songs** across five perceptual aesthetic dimensions aligned with professional musician judgments.

---

## 🌟 Key Features

- 🧠 **Pretrained neural models** for perceptual aesthetic evaluation
- 🎼 Predicts **five aesthetic dimensions**:
  - Overall Coherence
  - Memorability
  - Naturalness of Vocal Breathing and Phrasing
  - Clarity of Song Structure
  - Overall Musicality
<!-- - 🧪 Supports **batch evaluation** for model benchmarking -->
- 🎧 Accepts **full-length songs** (vocals + accompaniment) as input
- ⚙️ Simple inference interface

---

## 📦 Installation

Clone the repository and install dependencies:

```bash
git clone https://github.com/ASLP-lab/SongEval.git
cd SongEval
pip install -r requirements.txt
```

## 🚀 Quick Start

- Evaluate a single audio file:

```bash
python eval.py -i /path/to/audio.mp3 -o /path/to/output
```

- Evaluate a list of audio files:

```bash
python eval.py -i /path/to/audio_list.txt -o /path/to/output
```

- Evaluate all audio files in a directory:

```bash
python eval.py -i /path/to/audio_directory -o /path/to/output
```

- Force evaluation on CPU (⚠️ CPU evaluation may be significantly slower):

```bash
python eval.py -i /path/to/audio.wav -o /path/to/output --use_cpu True
```
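
The scores are written to `result.json` in the output directory, one entry per input file plus an `average` entry. A hypothetical example of its shape (the keys match `eval.py`; the values are illustrative):

```json
{
    "my_song": {
        "Coherence": 3.8123,
        "Musicality": 3.5241,
        "Memorability": 3.4410,
        "Clarity": 3.6672,
        "Naturalness": 3.5903
    },
    "average": {
        "Coherence": 3.8123,
        "Musicality": 3.5241,
        "Memorability": 3.4410,
        "Clarity": 3.6672,
        "Naturalness": 3.5903
    }
}
```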

## 🙏 Acknowledgement

This project is mainly organized by the Audio, Speech and Language Processing Lab [(ASLP@NPU)](http://www.npu-aslp.org/).

We sincerely thank the **Shanghai Conservatory of Music** for their expert guidance on music theory, aesthetics, and annotation design.
We also thank AISHELL for helping with the organization of the song annotations.

<p align="center"> <img src="assets/logo.png" alt="Shanghai Conservatory of Music Logo"/> </p>

## 📑 License

This project is released under the CC BY-NC-SA 4.0 license.

You are free to use, modify, and build upon it for non-commercial purposes, with attribution.

## 📚 Citation

If you use this toolkit or the SongEval dataset, please cite the following:

```
@article{yao2025songeval,
  title   = {SongEval: A Benchmark Dataset for Song Aesthetics Evaluation},
  author  = {Yao, Jixun and Ma, Guobin and Xue, Huixin and Chen, Huakang and Hao, Chunbo and Jiang, Yuepeng and Liu, Haohe and Yuan, Ruibin and Xu, Jin and Xue, Wei and others},
  journal = {arXiv preprint arXiv:2505.10793},
  year    = {2025}
}
```
BIN SongEval/assets/logo.png Normal file
Binary file not shown. (image, 1016 KiB)
184 SongEval/clap_score.py Normal file
@@ -0,0 +1,184 @@
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln

# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)

def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_speech_audioset_epoch_15_esc_89.98.pt'):
    """
    Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
    the LAION-CLAP audio embedding of the generated audio. LAION-CLAP: https://github.com/LAION-AI/CLAP

    This evaluation script assumes that audio_path files are identified with the ids in id2text.

    clap_score() evaluates all ids in id2text.

    GPU-based computation.

    Select one of the following models from https://github.com/LAION-AI/CLAP:
        - music_speech_audioset_epoch_15_esc_89.98.pt (used by MusicGen; the default here)
        - music_audioset_epoch_15_esc_90.14.pt
        - music_speech_epoch_15_esc_89.25.pt
        - 630k-audioset-fusion-best.pt (with "fusion" to handle longer inputs)

    Params:
    -- id2text: dictionary mapping id (generated audio filenames in audio_path)
                to text (the prompt used to generate the audio). clap_score() evaluates all ids in id2text.
    -- audio_path: path where the generated audio files to evaluate are available.
    -- audio_files_extension: file extension (default .wav) in audio_path.
    -- clap_model: one of the CLAP models above (default: 'music_speech_audioset_epoch_15_esc_89.98.pt').
    Returns:
    -- LAION-CLAP score
    """
    # load model
    if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
        clap_path = 'load/clap_score/music_speech_audioset_epoch_15_esc_89.98.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
        clap_path = 'load/clap_score/music_audioset_epoch_15_esc_90.14.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
        clap_path = 'load/clap_score/music_speech_epoch_15_esc_89.25.pt'
        model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
    elif clap_model == '630k-audioset-fusion-best.pt':
        url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
        clap_path = 'load/clap_score/630k-audioset-fusion-best.pt'
        model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
    else:
        raise ValueError('clap_model not implemented')

    # download clap_model if not already downloaded
    if not os.path.exists(clap_path):
        print('Downloading ', clap_model, '...')
        os.makedirs(os.path.dirname(clap_path), exist_ok=True)

        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        with open(clap_path, 'wb') as file:
            with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(chunk_size=8192):
                    file.write(data)
                    progress_bar.update(len(data))

    # fixing LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
    pkg = load_state_dict(clap_path)
    pkg.pop('text_branch.embeddings.position_ids', None)
    model.model.load_state_dict(pkg)
    model.eval()

    if not os.path.isdir(audio_path):
        raise ValueError('audio_path does not exist')

    if id2text:
        print('[EXTRACTING TEXT EMBEDDINGS] ')
        batch_size = 64
        text_emb = {}
        for i in tqdm(range(0, len(id2text), batch_size)):
            batch_ids = list(id2text.keys())[i:i+batch_size]
            batch_texts = [id2text[id] for id in batch_ids]
            with torch.no_grad():
                embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
            for id, emb in zip(batch_ids, embeddings):
                text_emb[id] = emb
    else:
        raise ValueError('Must specify id2text')

    print('[EVALUATING GENERATIONS] ', audio_path)
    score = 0
    count = 0
    for id in tqdm(id2text.keys()):
        file_path = os.path.join(audio_path, str(id) + audio_files_extension)
        with torch.no_grad():
            audio, _ = librosa.load(file_path, sr=48000, mono=True)  # CLAP expects 48 kHz audio
            audio = pyln.normalize.peak(audio, -1.0)
            audio = audio.reshape(1, -1)  # unsqueeze to (1, T)
            audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
            audio_embeddings = model.get_audio_embedding_from_data(x=audio, use_tensor=True)
        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
        score += cosine_sim
        count += 1

    return score / count if count > 0 else 0
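
# Example of calling clap_score() directly (hypothetical id and paths):
#   id2text = {"song_001": "an upbeat pop track with female vocals"}
#   score = clap_score(id2text, "outputs/", audio_files_extension=".wav")
#   print(float(score))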

if __name__ == "__main__":

    import json
    import argparse

    parser = argparse.ArgumentParser(description='Compute CLAP score for generated audio files.')
    parser.add_argument('--clap_model', type=str, default='630k-audioset-fusion-best.pt',
                        help='CLAP model to use for evaluation. Options: music_speech_audioset_epoch_15_esc_89.98.pt, music_audioset_epoch_15_esc_90.14.pt, music_speech_epoch_15_esc_89.25.pt, 630k-audioset-fusion-best.pt (default: 630k-audioset-fusion-best.pt)')
    parser.add_argument('--root_path', type=str, default='../wandb/run-20250627_172105-xpe7nh5n-worseInstr/generated_samples_text_conditioned_top_p_threshold_0.99_temperature_1.15_8',
                        help='Path to the directory containing generated audio files and the id2text mapping.')
    args = parser.parse_args()
    clap_model = args.clap_model
    root_path = args.root_path
    json_file_path = os.path.join(root_path, 'name2prompt.jsonl')
    generated_path = os.path.join(root_path, 'prompt_music')
    if not os.path.exists(generated_path):
        generated_path = root_path  # if there is no 'prompt_music' subfolder, use root_path directly

    with open(json_file_path, 'r') as f:
        id2text_dict = {}
        for line in f:
            item = json.loads(line)
            for k, v in item.items():
                id2text_dict[k] = v[0]
    print('length of id2text:', len(id2text_dict))

    # id2text = {k+'_1': v[0] for k, v in id2text_dict.items()}  # each key holds a list of prompts; take the first one
    id2text = {}
    for k, v in id2text_dict.items():
        if isinstance(v, list):
            id2text[k] = v[0]
            # check whether k exists as a wav file
            if os.path.exists(os.path.join(generated_path, str(k) + '.wav')):
                id2text[k] = v[0]
            else:
                # otherwise look for variations k_0, k_1, ... and check whether they exist
                for i in range(0, 10):  # assuming no more than 10 variations
                    if os.path.exists(os.path.join(generated_path, str(k) + '_' + str(i) + '.wav')):
                        new_key = str(k) + '_' + str(i)
                        id2text[new_key] = v[0]
    print('length of id2text after checking wav files:', len(id2text))

    # keep only the ids whose wav file actually exists
    new_id2text = {}
    for id in id2text.keys():
        file_path = os.path.join(generated_path, str(id) + '.wav')
        if os.path.exists(file_path):
            new_id2text[id] = id2text[id]
        else:
            print(f"Warning: {file_path} does not exist, skipping this id.")
    print('length of new_id2text:', len(new_id2text))

    """
    IMPORTANT: the audios in generated_path should have the same ids as in id2text.
    For MusicCaps, you can load id2text as above and each generated_path audio file
    corresponds to a prompt (text description) in MusicCaps. Files are named with ids, as follows:
        - your_model_outputs_folder/_-kssA-FOzU.wav
        - your_model_outputs_folder/_0-2meOf9qY.wav
        - your_model_outputs_folder/_1woPC5HWSg.wav
        ...
        - your_model_outputs_folder/ZzyWbehtt0M.wav
    """

    clp = clap_score(new_id2text, generated_path, audio_files_extension='.wav')
    print('CLAP score (cosine similarity):', clp)
6 SongEval/config.yaml Normal file
@@ -0,0 +1,6 @@
generator:
  _target_: model.Generator
  in_features: 1024
  ffd_hidden_size: 4096
  num_classes: 5
  attn_layer_num: 4
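
For context, eval.py turns this config into a model instance via Hydra. A minimal sketch of that wiring (the config path is relative to wherever the file sits; values come from the YAML above):

    from omegaconf import OmegaConf
    from hydra.utils import instantiate

    cfg = OmegaConf.load("config.yaml")
    model = instantiate(cfg.generator)  # builds model.Generator(in_features=1024, ...)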
456 SongEval/controlability.py Normal file
@@ -0,0 +1,456 @@
import json

generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
test_set_json = "dataset/midicaps/train.json"

generated_eval_json_path = f"{generate_path}/eval.json"
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"

# 1. Read the test set and build a mapping from prompt to entry
with open(test_set_json, 'r') as f:
    test_set = []
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line.strip())
        test_set.append(item)
prompt2item = {item['caption']: item for item in test_set if item['test_set'] is True}
print(f"Number of prompts in test set: {len(prompt2item)}")

# 2. Read name2prompt.jsonl and build a mapping from name to prompt
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r') as f:
    for line in f:
        obj = json.loads(line)
        name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and len(v) > 0})

# 3. Read eval.json
with open(generated_eval_json_path, 'r') as f:
    eval_items = []
    for line in f:
        if not line.strip():
            continue
        item = json.loads(line.strip())
        eval_items.append(item)

# 4. For each name, look up its prompt, make sure the prompt is in the test set,
#    then find the matching entry in eval.json
results = []
# turn the names of eval_items into relative names
for item in eval_items:
    item['name'] = item['name'].split('/')[-1]  # name is assumed to be a path; keep only the last component
    # drop everything after the second underscore
    if '_' in item['name']:
        item['name'] = item['name'].split('.')[0].split('_')[0] + '_' + item['name'].split('.')[0].split('_')[1]
    # print(f"Processed eval item name: {item['name']}")

for name, prompt in name2prompt.items():
    if prompt not in prompt2item:
        print(f"Prompt not found in test set: {prompt}")
        continue
    # find the matching entry in eval.json (assumes each eval.json entry has a 'name' field)
    eval_entry = next((item for item in eval_items if item.get('name') == name), None)
    if eval_entry is None:
        print(f"Eval entry not found for name: {name}")
        continue
    # ground-truth entry
    original_entry = prompt2item[prompt]
    results.append({
        'name': name,
        'prompt': prompt,
        'eval_entry': eval_entry,
        'original_entry': original_entry
    })
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")

def calculate_TBT_score(results):
    """
    Tempo Bin with Tolerance (TBT): the predicted BPM falls into the ground-truth
    tempo bin or a neighboring one.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'tempo' in eval_entry and 'tempo' in original_entry:
            eval_tempo = eval_entry['tempo'][0] if isinstance(eval_entry['tempo'], list) else eval_entry['tempo']
            original_tempo = original_entry['tempo']
            if original_tempo is None or eval_tempo is None:
                continue  # skip entries without a tempo
            # check whether eval_tempo falls inside the tolerance window around original_tempo
            if original_tempo - 10 <= eval_tempo <= original_tempo + 15:
                correct += 1
            total += 1
    TB_score = correct / total if total > 0 else 0
    print(f"TB Score: {TB_score:.4f} (Correct: {correct}, Total: {total})")
    return TB_score


def calculate_CK_score(results):
    """
    Correct Key (CK): the predicted key matches the ground-truth key.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' in eval_entry and 'key' in original_entry:
            eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
            eval_key = eval_key if eval_key is not None else "C major"  # default to C major
            original_key = original_entry['key'] if original_entry['key'] is not None else "C major"  # default to C major
            if original_key is None or eval_key is None:
                continue
            if eval_key == original_key:
                correct += 1
            total += 1
    CK_score = correct / total if total > 0 else 0
    print(f"CK Score: {CK_score:.4f} (Correct: {correct}, Total: {total})")
    return CK_score


def calculate_CKD_score(results):
    """
    Correct Key with Duplicates (CKD): the predicted key matches the ground-truth key
    or an equivalent key (i.e., a major key and its relative minor).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' in eval_entry and 'key' in original_entry:
            eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
            if eval_key is None:
                eval_key = "C major"  # default to C major
            original_key = original_entry['key'] if original_entry['key'] is not None else "C major"
            if original_key is None or eval_key is None:
                continue  # skip entries without a key
            # accept an exact match or a key sharing the same tonic (e.g., "C major" / "C minor")
            if eval_key == original_key or (eval_key.split(' ')[0] == original_key.split(' ')[0]):
                correct += 1
            total += 1
    CKD_score = correct / total if total > 0 else 0
    print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
    return CKD_score


def calculate_CTS_score(results):
    """
    Correct Time Signature (CTS): the predicted time signature matches the ground-truth
    time signature.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'time_signature' in eval_entry and 'time_signature' in original_entry:
            eval_time_signature = eval_entry['time_signature'][0] if isinstance(eval_entry['time_signature'], list) else eval_entry['time_signature']
            original_time_signature = original_entry['time_signature']
            if original_time_signature is None or eval_time_signature is None:
                continue  # skip entries without a time signature
            if eval_time_signature == original_time_signature:
                correct += 1
            else:
                # also accept meters the script treats as equivalent
                # (same fraction, or doubled numerator with the same denominator)
                eval_numerator, eval_denominator = map(int, eval_time_signature.split('/'))
                original_numerator, original_denominator = map(int, original_time_signature.split('/'))
                if (eval_numerator == original_numerator and eval_denominator == original_denominator) or \
                        (eval_numerator * 2 == original_numerator and eval_denominator == original_denominator):
                    correct += 1
            total += 1
    CTS_score = correct / total if total > 0 else 0
    print(f"CTS Score: {CTS_score:.4f} (Correct: {correct}, Total: {total})")
    return CTS_score


def calculate_ECM_score(results):
    """
    Exact Chord Match (ECM): the predicted chord sequence matches the ground truth
    exactly in order, chord root, and chord type, with tolerance for missing and
    excess chord instances.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'chord_summary' in eval_entry and 'chord_summary' in original_entry:
            eval_chord_summary = eval_entry['chord_summary'][0] if isinstance(eval_entry['chord_summary'], list) else eval_entry['chord_summary']
            original_chord_summary = original_entry['chord_summary']
            if original_chord_summary is None or eval_chord_summary is None:
                continue
            # both are lists of chord strings; require an exact match
            if eval_chord_summary == original_chord_summary:
                correct += 1
            total += 1
    ECM_score = correct / total if total > 0 else 0
    print(f"ECM Score: {ECM_score:.4f} (Correct: {correct}, Total: {total})")
    return ECM_score


def calculate_CMO_score(results):
    """
    Chord Match in any Order (CMO): the portion of the predicted chord sequence
    matching the ground-truth chord root and type, in any order.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'chords' in eval_entry and 'chord_summary' in original_entry:
            eval_chords_seq = eval_entry['chords']
            # drop the confidence scores, e.g. [['C', 0.464], ['G', 2.879]] -> ['C', 'G']
            if isinstance(eval_chords_seq, list) and len(eval_chords_seq) > 0 and isinstance(eval_chords_seq[0], list):
                eval_chords_seq = [chord[0] for chord in eval_chords_seq]
            original_chord_summary = original_entry['chord_summary']
            if original_chord_summary is None or eval_chords_seq is None:
                continue
            # count a hit when every ground-truth chord appears among the predictions (order-insensitive)
            eval_chords_set = set(eval_chords_seq)
            original_chord_set = set(original_chord_summary)
            if original_chord_set.issubset(eval_chords_set):
                correct += 1
            total += 1
    CMO_score = correct / total if total > 0 else 0
    print(f"CMO Score: {CMO_score:.4f} (Correct: {correct}, Total: {total})")
    return CMO_score


def calculate_CI_score(results):
    """
    Correct Instrument (CI): the predicted instruments match the ground-truth instruments.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
            eval_instrument = eval_entry['mapped_instruments_summary']
            original_instrument = original_entry['instrument_summary']
            if original_instrument is None or eval_instrument is None:
                continue
            # check whether the predictions cover all ground-truth instruments
            if isinstance(eval_instrument, list):
                eval_instrument_set = set(eval_instrument)
                original_instrument_set = set(original_instrument)
                if original_instrument_set.issubset(eval_instrument_set):
                    correct += 1
            else:
                if eval_instrument == original_instrument:
                    correct += 1
            total += 1
    CI_score = correct / total if total > 0 else 0
    print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_score


def calculate_CI_top1_score(results):
    """
    Correct Instrument Top-1 (CI_top1): at least one ground-truth instrument appears
    among the predicted instruments.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
            eval_instrument = eval_entry['mapped_instruments_summary']
            original_instrument = original_entry['instrument_summary']
            if original_instrument is None or eval_instrument is None:
                continue
            # count a hit if any ground-truth instrument appears in the predictions
            if isinstance(eval_instrument, list):
                eval_instrument_set = set(eval_instrument)
                original_instrument_set = set(original_instrument)
                for inst in original_instrument_set:
                    if inst in eval_instrument_set:
                        correct += 1
                        break
            else:
                if eval_instrument == original_instrument:
                    correct += 1
            total += 1
    CI_top1_score = correct / total if total > 0 else 0
    print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_top1_score


def calculate_CG_score(results):
    """
    Correct Genre (CG): the predicted genre matches the ground-truth genre.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'genre' in eval_entry and 'genre' in original_entry:
            eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
            original_genre = original_entry['genre']
            if original_genre is None or eval_genre is None:
                continue
            # check whether the predictions cover all ground-truth genres
            if isinstance(eval_genre, list):
                eval_genre_set = set(eval_genre)
                original_genre_set = set(original_genre)
                if original_genre_set.issubset(eval_genre_set):
                    correct += 1
            else:
                if eval_genre == original_genre:
                    correct += 1
            total += 1
    CG_score = correct / total if total > 0 else 0
    print(f"CG Score: {CG_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_score


def calculate_CG_top1_score(results):
    """
    Correct Genre Top-1 (CG_top1): at least one ground-truth genre appears among
    the predicted genres.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'genre' in eval_entry and 'genre' in original_entry:
            eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
            original_genre = original_entry['genre']
            if original_genre is None or eval_genre is None:
                continue
            # count a hit if any ground-truth genre appears in the predictions
            if isinstance(eval_genre, list):
                eval_genre_set = set(eval_genre)
                original_genre_set = set(original_genre)
                for gen in original_genre_set:
                    if gen in eval_genre_set:
                        correct += 1
                        break
            else:
                if eval_genre == original_genre:
                    correct += 1
            total += 1
    CG_top1_score = correct / total if total > 0 else 0
    print(f"CG Top-1 Score: {CG_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_top1_score


def calculate_CM_score(results):
    """
    Correct Mood (CM): the predicted mood matches the ground-truth mood.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mood' in eval_entry and 'mood' in original_entry:
            eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
            original_mood = original_entry['mood']
            if original_mood is None or eval_mood is None:
                continue
            # check whether the predictions cover all ground-truth moods
            if isinstance(eval_mood, list):
                eval_mood_set = set(eval_mood)
                original_mood_set = set(original_mood)
                if original_mood_set.issubset(eval_mood_set):
                    correct += 1
            else:
                if eval_mood == original_mood:
                    correct += 1
            total += 1
    CM_score = correct / total if total > 0 else 0
    print(f"CM Score: {CM_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_score


def calculate_CM_top1_score(results):
    """
    Correct Mood Top-1 (CM_top1): at least one ground-truth mood appears among
    the predicted moods.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mood' in eval_entry and 'mood' in original_entry:
            eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
            original_mood = original_entry['mood']
            if original_mood is None or eval_mood is None:
                continue
            # count a hit if any ground-truth mood appears in the predictions
            if isinstance(eval_mood, list):
                eval_mood_set = set(eval_mood)
                original_mood_set = set(original_mood)
                for mood in original_mood_set:
                    if mood in eval_mood_set:
                        correct += 1
                        break
            else:
                if eval_mood == original_mood:
                    correct += 1
            total += 1
    CM_top1_score = correct / total if total > 0 else 0
    print(f"CM Top-1 Score: {CM_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top1_score


def calculate_CM_top3_score(results):
    """
    Correct Mood Top-3 (CM_top3): at least three ground-truth moods (or all of them,
    if there are fewer than three) appear among the predicted moods.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mood' in eval_entry and 'mood' in original_entry:
            eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
            original_mood = original_entry['mood']
            if original_mood is None or eval_mood is None:
                continue
            # count a hit if at least three ground-truth moods are predicted
            if isinstance(eval_mood, list):
                eval_mood_set = set(eval_mood)
                original_mood_set = set(original_mood)
                if len(original_mood_set) <= 3 and original_mood_set.issubset(eval_mood_set):
                    correct += 1
                elif len(original_mood_set) > 3:
                    match_num = sum(1 for mood in original_mood_set if mood in eval_mood_set)
                    if match_num >= 3:
                        correct += 1
            else:
                if eval_mood == original_mood:
                    correct += 1
            total += 1
    CM_top3_score = correct / total if total > 0 else 0
    print(f"CM Top-3 Score: {CM_top3_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top3_score


def calculate_all_scores(results):
    """
    Calculate all scores and return them as a dictionary.
    """
    scores = {
        'TBT_score': calculate_TBT_score(results),
        'CK_score': calculate_CK_score(results),
        'CKD_score': calculate_CKD_score(results),
        'CTS_score': calculate_CTS_score(results),
        'ECM_score': calculate_ECM_score(results),
        'CMO_score': calculate_CMO_score(results),
        'CI_score': calculate_CI_score(results),
        'CI_top1_score': calculate_CI_top1_score(results),
        'CG_score': calculate_CG_score(results),
        'CG_top1_score': calculate_CG_top1_score(results),
        'CM_score': calculate_CM_score(results),
        'CM_top1_score': calculate_CM_top1_score(results),
        'CM_top3_score': calculate_CM_top3_score(results)
    }
    return scores


if __name__ == "__main__":
    scores = calculate_all_scores(results)
    print("All Scores:")
    for score_name, score_value in scores.items():
        print(f"{score_name}: {score_value:.4f}")

    # Save the results to a JSON file
    output_file = f"{generate_path}/results.json"
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)
    print(f"Results saved to {output_file}")
103 SongEval/ebr.py Normal file
@@ -0,0 +1,103 @@
import argparse
import os
import pandas as pd
import muspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


def compute_midi_metrics(file_path):
    """Compute the music metrics of a single MIDI file."""
    try:
        music = muspy.read(file_path)
        scale_consistency = muspy.scale_consistency(music)
        pitch_entropy = muspy.pitch_entropy(music)
        pitch_class_entropy = muspy.pitch_class_entropy(music)
        empty_beat_rate = muspy.empty_beat_rate(music)
        groove_consistency = muspy.groove_consistency(music, 12)
        metrics = {
            'scale_consistency': scale_consistency,
            'pitch_entropy': pitch_entropy,
            'pitch_class_entropy': pitch_class_entropy,
            'empty_beat_rate': empty_beat_rate,
            'groove_consistency': groove_consistency,
            'filename': os.path.basename(file_path)
        }
        return metrics
    except Exception as e:
        print(f"Error while processing {os.path.basename(file_path)}: {e}")
        return None
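
# Example for a single file (hypothetical path):
#   m = compute_midi_metrics("songs/example.mid")
#   if m is not None:
#       print(m["pitch_entropy"], m["groove_consistency"])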

def compute_directory_metrics(directory_path, num_workers=8):
    """Compute the music metrics of all MIDI files in a directory (multi-threaded)."""
    midi_files = []
    for root, _, files in os.walk(directory_path):
        for file in files:
            if file.lower().endswith(('.mid', '.midi')):
                midi_files.append(os.path.join(root, file))
    if not midi_files:
        print("No MIDI files found in the directory or its subfolders")
        return None

    all_metrics = []
    average_metrics = {
        'scale_consistency': 0,
        'pitch_entropy': 0,
        'pitch_class_entropy': 0,
        'empty_beat_rate': 0,
        'groove_consistency': 0
    }
    current_num = 0
    total_scale_consistency = 0
    total_pitch_entropy = 0
    total_pitch_class_entropy = 0
    total_empty_beat_rate = 0
    total_groove_consistency = 0
    print(f"Processing directory: {directory_path}")
    print(f"Found {len(midi_files)} MIDI files:")

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(compute_midi_metrics, midi_file): midi_file for midi_file in midi_files}
        for future in tqdm(as_completed(futures), total=len(midi_files), desc="Processing"):
            metrics = future.result()

            if metrics is not None:
                current_num += 1
                total_scale_consistency += metrics['scale_consistency']
                total_pitch_entropy += metrics['pitch_entropy']
                total_pitch_class_entropy += metrics['pitch_class_entropy']
                total_empty_beat_rate += metrics['empty_beat_rate']
                total_groove_consistency += metrics['groove_consistency']
                average_metrics['scale_consistency'] = total_scale_consistency / current_num
                average_metrics['pitch_entropy'] = total_pitch_entropy / current_num
                average_metrics['pitch_class_entropy'] = total_pitch_class_entropy / current_num
                average_metrics['empty_beat_rate'] = total_empty_beat_rate / current_num
                average_metrics['groove_consistency'] = total_groove_consistency / current_num
                print("current_metrics:", metrics)

                all_metrics.append(metrics)

    if not all_metrics:
        print("All files failed to process")
        return None

    df = pd.DataFrame(all_metrics)
    output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
    df.to_csv(output_csv, index=False)
    avg_metrics = df.mean(numeric_only=True)
    return df, avg_metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compute music metrics for all MIDI files in a directory")
    parser.add_argument("path", type=str, help="path to a directory containing MIDI files")
    parser.add_argument("--threads", type=int, default=1, help="number of worker threads (default: 1)")
    args = parser.parse_args()

    if not os.path.isdir(args.path):
        print(f"Error: path '{args.path}' does not exist or is not a directory")
    else:
        # compute_directory_metrics may return None; guard before unpacking
        res = compute_directory_metrics(args.path, num_workers=args.threads)
        if res is not None:
            result, averages = res
            print("\nDone! Results saved to midi_metrics_report.csv")
            print("\nAverage metric values:")
            print(averages.to_string())
150 SongEval/eval.py Normal file
@@ -0,0 +1,150 @@
import glob
import os
import json
import librosa
import numpy as np
import torch
import argparse
from muq import MuQ
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm


class Synthesizer(object):

    def __init__(self,
                 checkpoint_path,
                 input_path,
                 output_dir,
                 use_cpu: bool = False):

        self.checkpoint_path = checkpoint_path
        self.input_path = input_path
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.device = torch.device('cuda') if (torch.cuda.is_available() and (not use_cpu)) else torch.device('cpu')

    @torch.no_grad()
    def setup(self):
        train_config = OmegaConf.load(os.path.join(os.path.dirname(self.checkpoint_path), '../config.yaml'))
        model = instantiate(train_config.generator).to(self.device).eval()
        state_dict = load_file(self.checkpoint_path, device="cpu")
        model.load_state_dict(state_dict, strict=False)

        self.model = model
        self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
        self.muq = self.muq.to(self.device).eval()
        self.result_dict = {}

    @torch.no_grad()
    def synthesis(self):
        if os.path.isfile(self.input_path):
            if self.input_path.endswith(('.wav', '.mp3')):
                lines = [self.input_path]
            else:
                with open(self.input_path, "r") as f:
                    lines = [line for line in f]
            input_files = [{
                "input_path": line.strip(),
            } for line in lines]
            print(f"input filelist: {self.input_path}")
        elif os.path.isdir(self.input_path):
            input_files = [{
                "input_path": file,
            } for file in glob.glob(os.path.join(self.input_path, '*')) if file.lower().endswith(('.wav', '.mp3'))]
        else:
            raise ValueError(f"input_path {self.input_path} is not a file or directory")

        for input in tqdm(input_files):
            try:
                self.handle(**input)
            except Exception as e:
                print(e)
                continue
        # add the average over all evaluated files
        avg_values = {}
        for key in self.result_dict[list(self.result_dict.keys())[0]].keys():
            avg_values[key] = round(float(np.mean([self.result_dict[fid][key] for fid in self.result_dict])), 4)
        self.result_dict['average'] = avg_values
        # save results
        with open(os.path.join(self.output_dir, "result.json"), "w") as f:
            json.dump(self.result_dict, f, indent=4, ensure_ascii=False)

    @torch.no_grad()
    def handle(self, input_path):
        fid = os.path.basename(input_path).split('.')[0]
        if input_path.endswith('.npy'):
            input = np.load(input_path)

            # sanity checks on precomputed SSL features
            if len(input.shape) == 3 and input.shape[0] != 1:
                print('ssl_shape error', input_path)
                return
            if np.isnan(input).any():
                print('ssl nan', input_path)
                return

            input = torch.from_numpy(input).to(self.device)
            if len(input.shape) == 2:
                input = input.unsqueeze(0)

        if input_path.endswith(('.wav', '.mp3')):
            wav, sr = librosa.load(input_path, sr=24000)
            audio = torch.tensor(wav).unsqueeze(0).to(self.device)
            output = self.muq(audio, output_hidden_states=True)
            # use MuQ hidden layer 6 as the 1024-dim SSL feature (matches in_features in config.yaml)
            input = output["hidden_states"][6]

        values = {}
        scores_g = self.model(input).squeeze(0)
        values['Coherence'] = round(scores_g[0].item(), 4)
        values['Musicality'] = round(scores_g[1].item(), 4)
        values['Memorability'] = round(scores_g[2].item(), 4)
        values['Clarity'] = round(scores_g[3].item(), 4)
        values['Naturalness'] = round(scores_g[4].item(), 4)

        self.result_dict[fid] = values
        # free intermediate tensors
        del input, scores_g
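
# Programmatic use (paths hypothetical; mirrors the CLI entry point below):
#   s = Synthesizer("ckpt/model.safetensors", "songs/demo.wav", "out/")
#   s.setup()
#   s.synthesis()  # writes out/result.json with the five scores per file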

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i", "--input_path",
        type=str,
        required=True,
        help="Input audio: path to a single file, a text file listing audio paths, or a directory of audio files."
    )
    parser.add_argument(
        "-o", "--output_dir",
        type=str,
        required=True,
        help="Output directory for generated results (will be created if it doesn't exist)."
    )
    parser.add_argument(
        "--use_cpu",
        type=str,
        help="Force CPU mode even if a GPU is available (parsed as a string; any non-empty value forces CPU).",
        default=False
    )

    args = parser.parse_args()

    ckpt_path = "ckpt/model.safetensors"

    synthesizer = Synthesizer(checkpoint_path=ckpt_path,
                              input_path=args.input_path,
                              output_dir=args.output_dir,
                              use_cpu=args.use_cpu)

    synthesizer.setup()

    synthesizer.synthesis()
404 SongEval/generate-batch_easy.py Normal file
@@ -0,0 +1,404 @@
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process,set_start_method
|
||||
import torch
|
||||
import argparse
|
||||
from omegaconf import OmegaConf
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
from Amadeus.evaluation_utils import (
|
||||
wandb_style_config_to_omega_config,
|
||||
prepare_model_and_dataset_from_config,
|
||||
get_best_ckpt_path_and_config,
|
||||
Evaluator
|
||||
)
|
||||
from transformers import T5Tokenizer, T5EncoderModel
|
||||
|
||||
from Amadeus import model_zoo
|
||||
from Amadeus.symbolic_encoding import data_utils
|
||||
from Amadeus.model_zoo import AmadeusModel
|
||||
from Amadeus.symbolic_encoding.data_utils import TuneCompiler
|
||||
from Amadeus.symbolic_encoding.compile_utils import shift_and_pad
|
||||
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
|
||||
from Amadeus.symbolic_encoding import decoding_utils
|
||||
from Amadeus.train_utils import adjust_prediction_order
|
||||
from data_representation import vocab_utils
|
||||
from data_representation.vocab_utils import LangTokenVocab
|
||||
|
||||
|
||||
def get_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-wandb_exp_dir",
        required=True,
        type=str,
        help="wandb experiment directory",
    )
    parser.add_argument(
        "-generation_type",
        type=str,
        choices=('conditioned', 'unconditioned', 'text-conditioned'),
        default='unconditioned',
        help="generation type",
    )
    parser.add_argument(
        "-sampling_method",
        type=str,
        choices=('top_p', 'top_k'),
        default='top_p',
        help="sampling method",
    )
    parser.add_argument(
        "-threshold",
        type=float,
        default=0.99,
        help="sampling threshold (top-p probability mass or top-k cutoff)",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.15,
        help="sampling temperature",
    )
    parser.add_argument(
        "-num_samples",
        type=int,
        default=30,
        help="number of samples to generate",
    )
    parser.add_argument(
        "-num_target_measure",
        type=int,
        default=4,
        help="number of target measures for conditioned generation",
    )
    parser.add_argument(
        "-choose_selected_tunes",
        action='store_true',
        help="generate samples from selected tunes, only for SOD dataset",
    )
    parser.add_argument(
        "-generate_length",
        type=int,
        default=1024,
        help="length of the generated sequence",
    )
    parser.add_argument(
        "-num_processes",
        type=int,
        default=2,
        help="number of processes to use",
    )
    parser.add_argument(
        "-gpu_ids",
        type=str,
        default="0,5",
        help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
    )
    parser.add_argument(
        "-prompt",
        type=str,
        default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
        help="prompt for generation, only used for conditioned generation",
    )
    parser.add_argument(
        "-prompt_file",
        type=str,
        default="dataset/midicaps/train.json",
        help="file containing prompts for text-conditioned generation",
    )
    return parser

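# For reference, a typical invocation of this script might look like the
# following (the experiment directory name is purely illustrative):
#
#   python generate-batch_easy.py -wandb_exp_dir run-20250101_000000-amadeus \
#       -generation_type unconditioned -num_samples 30 \
#       -num_processes 2 -gpu_ids 0,1
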
def load_resources(wandb_exp_dir, device):
    """Load the model and dataset resources for a worker process."""
    wandb_dir = Path('wandb')
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
    config = OmegaConf.load(config_path)
    config = wandb_style_config_to_omega_config(config)

    # Load the checkpoint onto the specified device
    ckpt = torch.load(ckpt_path, map_location=device)
    model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
    model.load_state_dict(ckpt['model'], strict=False)
    model.to(device)
    model.eval()
    # torch.compile returns the compiled module, so the result must be kept
    model = torch.compile(model)
    print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare dataset for prompts
    condition_list = [x[1] for x in test_set.data_list]
    dataset_for_prompt = []
    for i in range(len(condition_list)):
        condition = test_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
        dataset_for_prompt.append((condition, condition_list[i]))

    return config, model, dataset_for_prompt, vocab

def conditioned_worker(process_idx, gpu_id, args, data_slice):
    """Worker process for conditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on the proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)

    # Process the assigned data slice
    for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
        batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
        batch_dir.mkdir(parents=True, exist_ok=True)
        evaluator.generate_samples_with_prompt(
            batch_dir,
            args.num_target_measure,
            tune_in_idx,
            tune_name,
            config.data_params.first_pred_feature,
            args.sampling_method,
            args.threshold,
            args.temperature,
            generation_length=args.generate_length
        )

def generate_samples_unconditioned(config, vocab, model, device, save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    encoding_scheme = config.nn_params.encoding_scheme

    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    try:
        in_beat_resolution = in_beat_resolution_dict[config.dataset]
    except KeyError:
        in_beat_resolution = 4  # Default resolution if the dataset is not listed
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    for i in range(num_samples):
        generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
        if encoding_scheme == 'nb':
            generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
        decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))

def generate_samples_with_text_prompt(config, vocab, model, device, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072, uid=1):
    encoding_scheme = config.nn_params.encoding_scheme
    # Note: the tokenizer and text encoder are reloaded on every call; callers
    # generating many samples may want to hoist these out of the function.
    tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
    encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').to(device)
    print(f"Using T5EncoderModel for text prompt: {prompt}")
    context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(device)
    context = encoder(**context).last_hidden_state
    in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
    try:
        in_beat_resolution = in_beat_resolution_dict[config.dataset]
    except KeyError:
        in_beat_resolution = 4  # Default resolution if the dataset is not listed
    midi_decoder_dict = {'remi': 'MidiDecoder4REMI', 'cp': 'MidiDecoder4CP', 'nb': 'MidiDecoder4NB'}
    decoder_name = midi_decoder_dict[encoding_scheme]
    decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)

    generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
    if encoding_scheme == 'nb':
        generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
    # Count the lines already in the jsonl file to determine the next prompt index.
    # Multiple worker processes append to this file concurrently, so the index can
    # repeat; the uid suffix keeps the output filenames unique regardless.
    jsonl_path = save_dir / "name2prompt.jsonl"
    if jsonl_path.exists():
        with open(jsonl_path, 'r') as f:
            current_idx = sum(1 for _ in f)
    else:
        current_idx = 0

    name = f"prompt_{current_idx}"
    name2prompt_dict = defaultdict(list)
    name2prompt_dict[name].append(prompt)
    with open(jsonl_path, 'a') as f:
        f.write(json.dumps(name2prompt_dict) + '\n')
    decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))

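# Each line appended to name2prompt.jsonl by the function above is a one-entry
# mapping from the generated name to its prompt, e.g. (illustrative content):
#   {"prompt_12": ["With a rhythm of 100 BPM, this classical piece ..."]}
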
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
    """Worker process for unconditioned generation"""
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on the proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate the assigned number of samples
    batch_dir = base_path
    generate_samples_unconditioned(
        config,
        vocab,
        model,
        device,  # generate_samples_unconditioned expects the device before save_dir
        batch_dir,
        num_samples,
        config.data_params.first_pred_feature,
        args.sampling_method,
        args.threshold,
        args.temperature,
        generation_length=args.generate_length,
        uid=f"{process_idx}"
    )

def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
    """Worker process for text-conditioned generation"""
    # num_samples is accepted for symmetry with the other workers; the length
    # of data_slice is what actually governs how many samples are generated.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')

    # Load resources on the proper device
    config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)

    # Create the output directory, named after the sampling settings
    base_path = Path('wandb') / args.wandb_exp_dir / \
        f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
    base_path.mkdir(parents=True, exist_ok=True)

    # Generate one sample per prompt in the assigned slice
    batch_dir = base_path
    for idx, tune_name in enumerate(data_slice):
        print(f"Process {process_idx} generating samples for tune: {tune_name}")
        generate_samples_with_text_prompt(
            config,
            vocab,
            model,
            device,
            batch_dir,
            prompt=tune_name,
            first_pred_feature=config.data_params.first_pred_feature,
            sampling_method=args.sampling_method,
            threshold=args.threshold,
            temperature=args.temperature,
            generation_length=args.generate_length,
            uid=f"{process_idx}_{idx}"
        )

def main():
    # Use the spawn start method, which is required for CUDA with multiprocessing
    set_start_method('spawn', force=True)
    args = get_argument_parser().parse_args()
    gpu_ids = list(map(int, args.gpu_ids.split(',')))

    # Validate GPU availability
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    if len(gpu_ids) == 0:
        raise ValueError("At least one GPU must be specified")

    # Validate process count
    if args.num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if len(gpu_ids) < args.num_processes:
        print(f"Warning: more processes ({args.num_processes}) than GPUs ({len(gpu_ids)}); some GPUs will be shared")

    # Prepare data slices and launch the worker processes
    processes = []
    try:
        if args.generation_type == 'conditioned':
            # Prepare selected tunes
            wandb_dir = Path('wandb') / args.wandb_exp_dir
            if not wandb_dir.exists():
                raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")

            # Load the test set only to enumerate tunes; the model is discarded.
            # The config is loaded first so the call matches load_resources() above.
            config = wandb_style_config_to_omega_config(OmegaConf.load(wandb_dir / "files" / "config.yaml"))
            _, test_set, _ = prepare_model_and_dataset_from_config(
                config,
                wandb_dir / "files" / "metadata.json",
                wandb_dir / "files" / "vocab.json"
            )

            if args.choose_selected_tunes and test_set.dataset == 'SOD':
                selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
                                  "Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
            else:
                selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]

            # Split the selected data across processes
            selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
            chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes  # ceiling division

            for i in range(args.num_processes):
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, len(selected_data))
                data_slice = selected_data[start_idx:end_idx]

                if not data_slice:
                    continue

                gpu_id = gpu_ids[i % len(gpu_ids)]
                p = Process(
                    target=conditioned_worker,
                    args=(i, gpu_id, args, data_slice)
                )
                processes.append(p)
                p.start()

        elif args.generation_type == 'unconditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes

            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                # The first `remainder` processes take one extra sample so the
                # total adds up, e.g. 30 samples over 4 processes -> 8, 8, 7, 7
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                p = Process(
                    target=unconditioned_worker,
                    args=(i, gpu_id, args, samples)
                )
                processes.append(p)
                p.start()

        elif args.generation_type == 'text-conditioned':
            samples_per_proc = args.num_samples // args.num_processes
            remainder = args.num_samples % args.num_processes
            # Load up to num_samples test-set prompts from the jsonl prompt file
            prompt_name_list = []
            with open(args.prompt_file, 'r') as f:
                for line in f:
                    if not line.strip():
                        continue
                    prompt_data = json.loads(line.strip())
                    prompt_text = prompt_data['caption']
                    if prompt_data['test_set'] is True:
                        prompt_name_list.append(prompt_text)
                    print("length of prompt_name_list:", len(prompt_name_list))
                    if len(prompt_name_list) >= args.num_samples:
                        print(f"Reached the limit of {args.num_samples} prompts.")
                        break
            for i in range(args.num_processes):
                gpu_id = gpu_ids[i % len(gpu_ids)]
                samples = samples_per_proc + (1 if i < remainder else 0)

                if samples <= 0:
                    continue

                # Split the prompts evenly across processes; any remainder beyond
                # an even split is dropped by this scheme
                start_idx = i * (len(prompt_name_list) // args.num_processes)
                end_idx = (i + 1) * (len(prompt_name_list) // args.num_processes)
                data_slice = prompt_name_list[start_idx:end_idx]

                p = Process(
                    target=text_conditioned_worker,
                    args=(i, gpu_id, args, samples, data_slice)
                )
                processes.append(p)
                p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

    except Exception as e:
        print(f"Error in main process: {str(e)}")
        for p in processes:
            p.terminate()
        raise


if __name__ == "__main__":
    main()
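The work-splitting arithmetic in main() above is easy to get wrong, so here is a standalone sketch (illustrative names, not part of the repository) showing how the unconditioned branch balances sample counts and assigns GPUs round-robin:

    # Standalone illustration of main()'s unconditioned work split.
    def split_work(num_samples, num_processes, gpu_ids):
        per_proc = num_samples // num_processes
        remainder = num_samples % num_processes
        plan = []
        for i in range(num_processes):
            samples = per_proc + (1 if i < remainder else 0)  # first `remainder` get one extra
            if samples > 0:
                plan.append((i, gpu_ids[i % len(gpu_ids)], samples))
        return plan

    # 30 samples over 4 processes on GPUs [0, 1] ->
    # [(0, 0, 8), (1, 1, 8), (2, 0, 7), (3, 1, 7)]
    print(split_work(30, 4, [0, 1]))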
68
SongEval/matrics.py
Normal file
@@ -0,0 +1,68 @@
import argparse
import os
import shutil
import tempfile
import numpy as np
import torch
from audioldm_eval import EvaluationHelper, EvaluationHelperParallel
import torch.multiprocessing as mp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--generation_path", type=str, required=True, help="Path to generated audio files")
    parser.add_argument("--target_path", type=str, required=True, help="Path to reference audio files")
    parser.add_argument("--force_paired", action="store_true", help="Force pairing by randomly selecting reference files")
    parser.add_argument("--gpu_mode", choices=["single", "multi"], default="single", help="Evaluation mode")
    parser.add_argument("--num_gpus", type=int, default=2, help="Number of GPUs for multi-GPU mode")
    args = parser.parse_args()

    # Handle forced pairing
    target_eval_path = args.target_path
    temp_dir = None
    if args.force_paired:
        print(f"Using forced pairing with reference files from {args.target_path}")
        temp_dir = tempfile.mkdtemp()
        target_eval_path = temp_dir

        # Collect generated filenames
        gen_files = []
        for root, _, files in os.walk(args.generation_path):
            for file in files:
                if file.endswith(".wav"):
                    gen_files.append(file)
        print(f"Found {len(gen_files)} generated files in {args.generation_path}")

        # Collect all reference files
        ref_files = []
        for root, _, files in os.walk(args.target_path):
            for file in files:
                if file.endswith(".wav"):
                    ref_files.append(os.path.join(root, file))

        # Select distinct random references matching the generated count
        # (replace=False requires at least as many references as generated files)
        selected_refs = np.random.choice(ref_files, len(gen_files), replace=False)
        print(f"Selected {len(selected_refs)} reference files for evaluation.")

        # Copy the selected references into the temp dir under the generated filenames
        for gen_file, ref_path in zip(gen_files, selected_refs):
            shutil.copy(ref_path, os.path.join(temp_dir, gen_file))

    device = torch.device("cuda:0") if args.gpu_mode == "single" else None

    try:
        if args.gpu_mode == "single":
            print("Running single GPU evaluation...")
            evaluator = EvaluationHelper(16000, device)
            metrics = evaluator.main(args.generation_path, target_eval_path)
        else:
            print(f"Running multi-GPU evaluation on {args.num_gpus} GPUs...")
            evaluator = EvaluationHelperParallel(16000, args.num_gpus)
            metrics = evaluator.main(args.generation_path, target_eval_path)
        print("Evaluation completed.")

    finally:
        # Clean up the temporary directory
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)


if __name__ == "__main__":
    main()
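One caveat in the forced-pairing path above: np.random.choice with replace=False raises a ValueError when there are fewer reference files than generated files. A tiny standalone illustration (hypothetical filenames):

    import numpy as np

    refs = [f"ref_{i}.wav" for i in range(10)]
    gens = [f"gen_{i}.wav" for i in range(4)]
    # replace=False requires len(refs) >= len(gens), else ValueError
    picked = np.random.choice(refs, len(gens), replace=False)
    print(list(zip(gens, picked)))  # each generated file gets a distinct reference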
66
SongEval/model.py
Normal file
@@ -0,0 +1,66 @@
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn


class Generator(nn.Module):

    def __init__(self,
                 in_features,
                 ffd_hidden_size,
                 num_classes,
                 attn_layer_num,
                 ):
        super(Generator, self).__init__()

        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=8,
                    dropout=0.2,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )

        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features)
        )

        self.dropout = nn.Dropout(0.2)

        # Mean- and max-pooled features are concatenated, hence in_features * 2
        self.fc = nn.Linear(in_features * 2, num_classes)

        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        '''
        ssl_feature: [B, T, D]
        output: [B, num_classes]
        '''

        B, T, D = ssl_feature.shape

        ssl_feature = self.ffd(ssl_feature)

        tmp_ssl_feature = ssl_feature

        for attn in self.attn:
            tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)

        # Concatenate mean pooling of the attended features with max pooling of
        # the pre-attention features -> [B, 2D]
        ssl_feature = self.dropout(torch.concat([torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]], dim=1))

        x = self.fc(ssl_feature)  # [B, num_classes]

        # tanh output in (-1, 1) is rescaled to (1, 5), i.e. a 1-to-5 rating scale
        x = self.proj(x) * 2.0 + 3

        return x
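As a shape sanity check for Generator, a minimal usage sketch follows; the 1024-dim input and the five output classes are assumptions about the surrounding pipeline, not values fixed by this file. (Note that model.py also imports einops and numpy, which are not pinned in requirements.txt below.)

    import torch
    from model import Generator  # SongEval/model.py; run from inside SongEval/

    # Hypothetical sizes: 1024-d SSL features, 5 rating dimensions
    gen = Generator(in_features=1024, ffd_hidden_size=2048, num_classes=5, attn_layer_num=2)
    feats = torch.randn(2, 250, 1024)  # [B, T, D]
    scores = gen(feats)                # [B, 5]; each value in (1.0, 5.0) via tanh * 2 + 3
    assert scores.shape == (2, 5)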
4
SongEval/requirements.txt
Normal file
@@ -0,0 +1,4 @@
librosa==0.11.0
torch==2.7.0
muq==0.1.0
hydra-core==1.3.2
3397
SongEval/result.json
Normal file
File diff suppressed because it is too large