first commit

2025-09-08 14:49:28 +08:00
commit 80333dff74
160 changed files with 30655 additions and 0 deletions

BIN
SongEval/.DS_Store vendored Normal file

Binary file not shown.

201
SongEval/LICENSE Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

88
SongEval/README.md Normal file
View File

@ -0,0 +1,88 @@
# 🎵 SongEval: A Benchmark Dataset for Song Aesthetics Evaluation
[![Hugging Face Dataset](https://img.shields.io/badge/HuggingFace-Dataset-blue)](https://huggingface.co/datasets/ASLP-lab/SongEval)
[![Arxiv Paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/pdf/2505.10793)
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
This repository provides a **trained aesthetic evaluation toolkit** based on [SongEval](https://huggingface.co/datasets/ASLP-lab/SongEval), the first large-scale, open-source dataset for human-perceived song aesthetics. The toolkit enables **automatic scoring of generated songs** across five perceptual aesthetic dimensions aligned with professional musicians' judgments.
---
## 🌟 Key Features
- 🧠 **Pretrained neural models** for perceptual aesthetic evaluation
- 🎼 Predicts **five aesthetic dimensions**:
- Overall Coherence
- Memorability
- Naturalness of Vocal Breathing and Phrasing
- Clarity of Song Structure
- Overall Musicality
<!-- - 🧪 Supports **batch evaluation** for model benchmarking -->
- 🎧 Accepts **full-length songs** (vocals + accompaniment) as input
- ⚙️ Simple inference interface
---
## 📦 Installation
Clone the repository and install dependencies:
```bash
git clone https://github.com/ASLP-lab/SongEval.git
cd SongEval
pip install -r requirements.txt
```
## 🚀 Quick Start
- Evaluate a single audio file:
```bash
python eval.py -i /path/to/audio.mp3 -o /path/to/output
```
- Evaluate a list of audio files:
```bash
python eval.py -i /path/to/audio_list.txt -o /path/to/output
```
- Evaluate all audio files in a directory:
```bash
python eval.py -i /path/to/audio_directory -o /path/to/output
```
- Force evaluation on CPU (⚠️ CPU evaluation may be significantly slower):
```bash
python eval.py -i /path/to/audio.wav -o /path/to/output --use_cpu True
```
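After a run finishes, `eval.py` writes a `result.json` into the output directory that maps each input file's basename to its five scores and adds an `average` entry. A minimal sketch for reading it (the path below is whatever you passed to `-o`):
```python
import json

# load the scores written by eval.py; "output" is the directory passed via -o
with open("output/result.json") as f:
    results = json.load(f)

# each entry holds the five dimension scores; "average" aggregates over all inputs
print(results["average"])  # {"Coherence": ..., "Musicality": ..., "Memorability": ..., "Clarity": ..., "Naturalness": ...}
```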
## 🙏 Acknowledgement
This project is mainly organized by the Audio, Speech and Language Processing Lab [(ASLP@NPU)](http://www.npu-aslp.org/).
We sincerely thank the **Shanghai Conservatory of Music** for their expert guidance on music theory, aesthetics, and annotation design.
We also thank AISHELL for helping with the organization of the song annotations.
<p align="center"> <img src="assets/logo.png" alt="Shanghai Conservatory of Music Logo"/> </p>
## 📑 License
This project is released under the CC BY-NC-SA 4.0 license.
You are free to use, modify, and build upon it for non-commercial purposes, with attribution.
## 📚 Citation
If you use this toolkit or the SongEval dataset, please cite the following:
```
@article{yao2025songeval,
title = {SongEval: A Benchmark Dataset for Song Aesthetics Evaluation},
author = {Yao, Jixun and Ma, Guobin and Xue, Huixin and Chen, Huakang and Hao, Chunbo and Jiang, Yuepeng and Liu, Haohe and Yuan, Ruibin and Xu, Jin and Xue, Wei and others},
journal = {arXiv preprint arXiv:2505.10793},
year = {2025}
}
```

BIN
SongEval/assets/logo.png Normal file

Binary file not shown.

Size: 1016 KiB

184
SongEval/clap_score.py Normal file
View File

@ -0,0 +1,184 @@
import os
import requests
from tqdm import tqdm
import torch
import numpy as np
import laion_clap
from clap_module.factory import load_state_dict
import librosa
import pyloudnorm as pyln
# following documentation from https://github.com/LAION-AI/CLAP
def int16_to_float32(x):
return (x / 32767.0).astype(np.float32)
def float32_to_int16(x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)
def clap_score(id2text, audio_path, audio_files_extension='.wav', clap_model='music_speech_audioset_epoch_15_esc_89.98.pt'):
"""
Cosine similarity is computed between the LAION-CLAP text embedding of the given prompt and
the LAION-CLAP audio embedding of the generated audio. LAION-CLAP: https://github.com/LAION-AI/CLAP
This evaluation script assumes that audio_path files are identified with the ids in id2text.
clap_score() evaluates all ids in id2text.
GPU-based computation.
Select one of the following models from https://github.com/LAION-AI/CLAP:
- music_speech_audioset_epoch_15_esc_89.98.pt (used by musicgen)
- music_audioset_epoch_15_esc_90.14.pt
- music_speech_epoch_15_esc_89.25.pt
- 630k-audioset-fusion-best.pt (our default, with "fusion" to handle longer inputs)
Params:
-- id2text: dictionary with the mapping between id (generated audio filenames in audio_path)
and text (prompt used to generate audio). clap_score() evaluates all ids in id2text.
-- audio_path: path where the generated audio files to evaluate are available.
-- audio_files_extension: files extension (default .wav) in eval_path.
-- clap_model: choose one of the above clap_models (default: '630k-audioset-fusion-best.pt').
Returns:
-- LAION-CLAP score (mean cosine similarity over all ids in id2text)
"""
# load model
if clap_model == 'music_speech_audioset_epoch_15_esc_89.98.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt'
clap_path = 'load/clap_score/music_speech_audioset_epoch_15_esc_89.98.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_audioset_epoch_15_esc_90.14.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_audioset_epoch_15_esc_90.14.pt'
clap_path = 'load/clap_score/music_audioset_epoch_15_esc_90.14.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == 'music_speech_epoch_15_esc_89.25.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_epoch_15_esc_89.25.pt'
clap_path = 'load/clap_score/music_speech_epoch_15_esc_89.25.pt'
model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cuda')
elif clap_model == '630k-audioset-fusion-best.pt':
url = 'https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-fusion-best.pt'
clap_path = 'load/clap_score/630k-audioset-fusion-best.pt'
model = laion_clap.CLAP_Module(enable_fusion=True, device='cuda')
else:
raise ValueError('clap_model not implemented')
# download clap_model if not already downloaded
if not os.path.exists(clap_path):
print('Downloading ', clap_model, '...')
os.makedirs(os.path.dirname(clap_path), exist_ok=True)
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(clap_path, 'wb') as file:
with tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
for data in response.iter_content(chunk_size=8192):
file.write(data)
progress_bar.update(len(data))
# fixing LAION-CLAP issue, see: https://github.com/LAION-AI/CLAP/issues/118
pkg = load_state_dict(clap_path)
pkg.pop('text_branch.embeddings.position_ids', None)
model.model.load_state_dict(pkg)
model.eval()
if not os.path.isdir(audio_path):
raise ValueError('audio_path does not exist')
if id2text:
print('[EXTRACTING TEXT EMBEDDINGS] ')
batch_size = 64
text_emb = {}
for i in tqdm(range(0, len(id2text), batch_size)):
batch_ids = list(id2text.keys())[i:i+batch_size]
batch_texts = [id2text[id] for id in batch_ids]
with torch.no_grad():
embeddings = model.get_text_embedding(batch_texts, use_tensor=True)
for id, emb in zip(batch_ids, embeddings):
text_emb[id] = emb
else:
raise ValueError('Must specify id2text')
print('[EVALUATING GENERATIONS] ', audio_path)
score = 0
count = 0
for id in tqdm(id2text.keys()):
file_path = os.path.join(audio_path, str(id)+audio_files_extension)
with torch.no_grad():
audio, _ = librosa.load(file_path, sr=48000, mono=True) # sample rate should be 48000
audio = pyln.normalize.peak(audio, -1.0)
audio = audio.reshape(1, -1) # unsqueeze (1,T)
audio = torch.from_numpy(int16_to_float32(float32_to_int16(audio))).float()
audio_embeddings = model.get_audio_embedding_from_data(x = audio, use_tensor=True)
cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_emb[id].unsqueeze(0), dim=1, eps=1e-8)[0]
score += cosine_sim
count += 1
return score / count if count > 0 else 0
if __name__ == "__main__":
import pandas as pd
import json
import argparse
parser = argparse.ArgumentParser(description='Compute CLAP score for generated audio files.')
parser.add_argument('--clap_model', type=str, default='630k-audioset-fusion-best.pt',
help='CLAP model to use for evaluation. Options: music_speech_audioset_epoch_15_esc_89.98.pt, music_audioset_epoch_15_esc_90.14.pt, music_speech_epoch_15_esc_89.25.pt, 630k-audioset-fusion-best.pt (default: 630k-audioset-fusion-best.pt)')
parser.add_argument('--root_path', type=str, default='../wandb/run-20250627_172105-xpe7nh5n-worseInstr/generated_samples_text_conditioned_top_p_threshold_0.99_temperature_1.15_8',
help='Path to the directory containing generated audio files and id2text mapping.')
args = parser.parse_args()
clap_model = args.clap_model
root_path = args.root_path
json_file_path = os.path.join(root_path, 'name2prompt.jsonl')
generated_path = os.path.join(root_path, 'prompt_music')
if not os.path.exists(generated_path):
generated_path = root_path # if there is no 'prompt_music' subfolder, use root_path directly
with open(json_file_path, 'r') as f:
id2text_dict = {}
for line in f:
item = json.loads(line)
for k, v in item.items():
id2text_dict[k] = v[0]
print('length of id2text:', len(id2text_dict))
# id2text = {k+'_1': v[0] for k, v in id2text_dict.items()} # assuming each key has a list of prompts, we take the first one
id2text = {}
for k, v in id2text_dict.items():
if isinstance(v, list):
id2text[k] = v[0]
# check if k exists as a wav file
if os.path.exists(os.path.join(generated_path, str(k)+'.wav')):
id2text[k] = v[0]
else:
# find k_*, k_1, k_2, ... and check if they exist
for i in range(0, 10): # assuming no more than 10 variations
if os.path.exists(os.path.join(generated_path, str(k)+'_'+str(i)+'.wav')):
new_key = str(k) + '_' + str(i)
id2text[new_key] = v[0]
print('length of id2text after checking wav files:', len(id2text))
# check if the wav file exists
new_id2text = {}
for id in id2text.keys():
file_path = os.path.join(generated_path, str(id)+'.wav')
if os.path.exists(file_path):
new_id2text[id] = id2text[id]
else:
print(f"Warning: {file_path} does not exist, skipping this id.")
print('length of new_id2text:', len(new_id2text))
"""
IMPORTANT: the audios in generated_path should have the same ids as in id2text.
For musiccaps, you can load id2text as above and each generated_path audio file
corresponds to a prompt (text description) in musiccaps. Files are named with ids, as follows:
- your_model_outputs_folder/_-kssA-FOzU.wav
- your_model_outputs_folder/_0-2meOf9qY.wav
- your_model_outputs_folder/_1woPC5HWSg.wav
...
- your_model_outputs_folder/ZzyWbehtt0M.wav
"""
clp = clap_score(new_id2text, generated_path, audio_files_extension='.wav')
print('CLAP score (cosine similarity):', clp)
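For reference, `clap_score()` can also be called on its own; a minimal sketch (the ids, prompts, and directory below are placeholders, and each id must exist as `<audio_path>/<id>.wav`):
```python
# minimal sketch; placeholder ids/prompts/paths, run from the SongEval directory
from clap_score import clap_score

id2text = {
    "song_000": "An upbeat pop track with female vocals and acoustic guitar.",
    "song_001": "A slow, melancholic piano ballad with soft strings.",
}
score = clap_score(id2text, "generated_audio", audio_files_extension=".wav")
print("CLAP score (cosine similarity):", float(score))
```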

6
SongEval/config.yaml Normal file
View File

@ -0,0 +1,6 @@
generator:
_target_: model.Generator
in_features: 1024
ffd_hidden_size: 4096
num_classes: 5
attn_layer_num: 4
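These keys mirror the constructor arguments of `model.Generator`; `eval.py` builds the scorer from this file via Hydra. A minimal sketch of that wiring (assuming the SongEval directory is the working directory):
```python
# minimal sketch of how eval.py instantiates the generator from config.yaml
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
# instantiate() imports the _target_ class and passes the remaining keys as kwargs,
# i.e. model.Generator(in_features=1024, ffd_hidden_size=4096, num_classes=5, attn_layer_num=4)
generator = instantiate(cfg.generator)
```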

456
SongEval/controlability.py Normal file
View File

@ -0,0 +1,456 @@
import json
generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
test_set_json = "dataset/midicaps/train.json"
generated_eval_json_path = f"{generate_path}/eval.json"
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"
# 1. Read the test set and build a mapping from prompt to entry
with open(test_set_json, 'r') as f:
test_set =[]
for line in f:
if not line.strip():
continue
item = json.loads(line.strip())
test_set.append(item)
prompt2item = {item['caption']: item for item in test_set if item['test_set'] is True}
print(f"Number of prompts in test set: {len(prompt2item)}")
# 2. Read name2prompt.jsonl and build a mapping from name to prompt
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r') as f:
for line in f:
obj = json.loads(line)
name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and len(v) > 0})
# 3. Read eval.json
with open(generated_eval_json_path, 'r') as f:
eval_items = []
for line in f:
if not line.strip():
continue
item = json.loads(line.strip())
eval_items.append(item)
# 4. For each name, find its prompt, make sure the prompt is in the test set, then find the matching entry in eval.json
results = []
# turn the name of eval_items into relative name
for item in eval_items:
item['name'] = item['name'].split('/')[-1] # assume name is a path; keep only the last component as the relative name
# keep only the part up to the second underscore
if '_' in item['name']:
item['name'] = item['name'].split('.')[0].split('_')[0] + '_' + item['name'].split('.')[0].split('_')[1]
# print(f"Processed eval item name: {item['name']}")
for name, prompt in name2prompt.items():
if prompt not in prompt2item:
print(f"Prompt not found in test set: {prompt}")
continue
# 找到 eval.json 里对应的条目(假设 eval.json 里有 name 字段)
eval_entry = next((item for item in eval_items if item.get('name') == name), None)
if eval_entry is None:
print(f"Eval entry not found for name: {name}")
continue
# original entry from the test set
original_entry = prompt2item[prompt]
results.append({
'name': name,
'prompt': prompt,
'eval_entry': eval_entry,
'original_entry': original_entry
})
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")
def calculate_TBT_score(results):
"""
• Tempo Bin with Tolerance (TBT): The predicted bpm falls into the ground truth tempo bin or
a neighboring one.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'tempo' in eval_entry and 'tempo' in original_entry:
eval_tempo = eval_entry['tempo'][0] if isinstance(eval_entry['tempo'], list) else eval_entry['tempo']
original_tempo = original_entry['tempo']
if original_tempo is None or eval_tempo is None:
continue # skip if either entry has no tempo
# check whether eval_tempo falls within the tolerance window around original_tempo
if original_tempo - 10 <= eval_tempo <= original_tempo + 15:
correct += 1
total += 1
TB_score = correct / total if total > 0 else 0
print(f"TB Score: {TB_score:.4f} (Correct: {correct}, Total: {total})")
return TB_score
def calculate_CK_score(results):
"""
• Correct Key (CK): The predicted key matches the ground truth key.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'key' in eval_entry and 'key' in original_entry:
eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
eval_key = eval_key if eval_key is not None else "C major" # default to C major
original_key = original_entry['key'] if original_entry['key'] is not None else "C major" # default to C major
if original_key is None or eval_key is None:
continue
if eval_key == original_key:
correct += 1
total += 1
CK_score = correct / total if total > 0 else 0
print(f"CK Score: {CK_score:.4f} (Correct: {correct}, Total: {total})")
return CK_score
def calculate_CKD_score(results):
"""
Correct Key with Duplicates (CKD): The predicted key matches the ground truth key or an equivalent key (i.e., a major key and its relative minor).
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'key' in eval_entry and 'key' in original_entry:
eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
if eval_key is None:
eval_key = "C major" # 默认值为 C 大调
original_key = original_entry['key'] if original_entry['key'] is not None else "C major"
if original_key is None or eval_key is None:
continue # skip if either entry has no key
# check whether eval_key equals original_key or shares the same tonic (e.g., C major vs. C minor)
if eval_key == original_key or (eval_key.split(' ')[0] == original_key.split(' ')[0]):
correct += 1
total += 1
CKD_score = correct / total if total > 0 else 0
print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
return CKD_score
def calculate_CTS_score(results):
"""
• Correct Time Signature (CTS): The predicted time signature matches the ground truth time signature.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'time_signature' in eval_entry and 'time_signature' in original_entry:
eval_time_signature = eval_entry['time_signature'][0] if isinstance(eval_entry['time_signature'], list) else eval_entry['time_signature']
original_time_signature = original_entry['time_signature']
if original_time_signature is None or eval_time_signature is None:
continue # skip if either entry has no time signature
if eval_time_signature == original_time_signature:
correct += 1
else:
# check for equivalent meters (e.g., 4/4 and 2/2)
eval_numerator, eval_denominator = map(int, eval_time_signature.split('/'))
original_numerator, original_denominator = map(int, original_time_signature.split('/'))
if (eval_numerator == original_numerator and eval_denominator == original_denominator) or \
(eval_numerator * 2 == original_numerator and eval_denominator == original_denominator):
correct += 1
total += 1
CTS_score = correct / total if total > 0 else 0
print(f"CTS Score: {CTS_score:.4f} (Correct: {correct}, Total: {total})")
return CTS_score
def calculate_ECM_score(results):
"""
Exact Chord Match (ECM): The predicted
chord sequence matches the ground truth exactly
in terms of order, chord root, and chord type, with
tolerance for missing and excess chord instances.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'chord_summary' in eval_entry and 'chord_summary' in original_entry:
eval_chord_summary = eval_entry['chord_summary'][0] if isinstance(eval_entry['chord_summary'], list) else eval_entry['chord_summary']
original_chord_summary = original_entry['chord_summary']
if original_chord_summary is None or eval_chord_summary is None:
continue
# check whether eval_chord_summary matches original_chord_summary (both are lists of chord strings)
if eval_chord_summary == original_chord_summary:
correct += 1
total += 1
ECM_score = correct / total if total > 0 else 0
print(f"ECM Score: {ECM_score:.4f} (Correct: {correct}, Total: {total})")
return ECM_score
def calculate_CMO_score(results):
"""
• Chord Match in any Order (CMO): The portion of predicted chord sequence matching the
ground truth chord root and type, in any order
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'chords' in eval_entry and 'chord_summary' in original_entry:
eval_chords_seq = eval_entry['chords']
# remove the confidence score from eval_chords_seq
if isinstance(eval_chords_seq, list) and len(eval_chords_seq) > 0 and isinstance(eval_chords_seq[0], list):
eval_chords_seq = [chord[0] for chord in eval_chords_seq]
original_chord_summary = original_entry['chord_summary']
if original_chord_summary is None or eval_chords_seq is None:
continue
# check whether eval_chords_seq covers original_chord_summary (both are lists)
eval_chords_set = set(eval_chords_seq) # [['C', 0.464399092], ['G', 2.879274376]]
original_chord_set = set(original_chord_summary) # ['G', 'C']
if original_chord_set.issubset(eval_chords_set):
correct += 1
else:
if original_chord_set == eval_chords_set:
correct += 1
total += 1
CMO_score = correct / total if total > 0 else 0
print(f"CMO Score: {CMO_score:.4f} (Correct: {correct}, Total: {total})")
return CMO_score
def calculate_CI_score(results):
"""
•Correct Instrument (CI): The predicted instrument matches the ground truth instrument.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(eval_entry['mapped_instruments'], list) else eval_entry['mapped_instruments']
original_instrument = original_entry['instrument_summary']
if original_instrument is None or eval_instrument is None:
continue
# check whether eval_instrument covers all instruments in original_instrument
if isinstance(eval_instrument, list):
eval_instrument_set = set(eval_instrument)
original_instrument_set = set(original_instrument)
if original_instrument_set.issubset(eval_instrument_set):
correct += 1
else:
if eval_instrument == original_instrument:
correct += 1
total += 1
CI_score = correct / total if total > 0 else 0
print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
return CI_score
def calculate_CI_top1_score(results):
"""
• Correct Instrument Top-1 (CI_top1): at least one ground-truth instrument
appears among the predicted instruments.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(eval_entry['mapped_instruments'], list) else eval_entry['mapped_instruments']
original_instrument = original_entry['instrument_summary']
if original_instrument is None or eval_instrument is None:
continue
# check whether eval_instrument contains at least one instrument from original_instrument
if isinstance(eval_instrument, list):
eval_instrument_set = set(eval_instrument)
original_instrument_set = set(original_instrument)
for inst in original_instrument_set:
if inst in eval_instrument_set:
correct += 1
break
else:
if eval_instrument == original_instrument:
correct += 1
total += 1
CI_top1_score = correct / total if total > 0 else 0
print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CI_top1_score
def calculate_CG_score(results):
"""
• Correct Genre (CG): The predicted genre matches the ground truth genre.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'genre' in eval_entry and 'genre' in original_entry:
eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
original_genre = original_entry['genre']
if original_genre is None or eval_genre is None:
continue
# check whether eval_genre covers all genres in original_genre
if isinstance(eval_genre, list):
eval_genre_set = set(eval_genre)
original_genre_set = set(original_genre)
if original_genre_set.issubset(eval_genre_set):
correct += 1
else:
if eval_genre == original_genre:
correct += 1
total += 1
CG_score = correct / total if total > 0 else 0
print(f"CG Score: {CG_score:.4f} (Correct: {correct}, Total: {total})")
return CG_score
def calculate_CG_top1_score(results):
"""
• Correct Genre Top-1 (CG_top1): at least one ground-truth genre appears among the predicted genres.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'genre' in eval_entry and 'genre' in original_entry:
eval_genre = eval_entry['genre'][0] if isinstance(eval_entry['genre'], list) else eval_entry['genre']
original_genre = original_entry['genre']
if original_genre is None or eval_genre is None:
continue
# check whether eval_genre contains at least one genre from original_genre
if isinstance(eval_genre, list):
eval_genre_set = set(eval_genre)
original_genre_set = set(original_genre)
for gen in original_genre_set:
if gen in eval_genre_set:
correct += 1
break
else:
if eval_genre == original_genre:
correct += 1
total += 1
CG_top1_score = correct / total if total > 0 else 0
print(f"CG Top-1 Score: {CG_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CG_top1_score
def calculate_CM_score(results):
"""
• Correct Mood (CM): The predicted mood matches the ground truth mood.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether eval_mood covers all moods in original_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
if original_mood_set.issubset(eval_mood_set):
correct += 1
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_score = correct / total if total > 0 else 0
print(f"CM Score: {CM_score:.4f} (Correct: {correct}, Total: {total})")
return CM_score
def calculate_CM_top1_score(results):
"""
• Correct Mood Top-1 (CM_top1): at least one ground-truth mood appears among the predicted moods.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether eval_mood contains at least one mood from original_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
for mood in original_mood_set:
if mood in eval_mood_set:
correct += 1
break
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_top1_score = correct / total if total > 0 else 0
print(f"CM Top-1 Score: {CM_top1_score:.4f} (Correct: {correct}, Total: {total})")
return CM_top1_score
def calculate_CM_top3_score(results):
"""
• Correct Mood Top-3 (CM_top3): at least three ground-truth moods (or all of them, if there are fewer than three) appear among the predicted moods.
"""
correct = 0
total = 0
for result in results:
eval_entry = result['eval_entry']
original_entry = result['original_entry']
if 'mood' in eval_entry and 'mood' in original_entry:
eval_mood = eval_entry['mood'][0] if isinstance(eval_entry['mood'], list) else eval_entry['mood']
original_mood = original_entry['mood']
if original_mood is None or eval_mood is None:
continue
# check whether eval_mood contains at least three moods from original_mood
if isinstance(eval_mood, list):
eval_mood_set = set(eval_mood)
original_mood_set = set(original_mood)
if len(original_mood_set) <= 3 and original_mood_set.issubset(eval_mood_set):
correct += 1
elif len(original_mood_set) > 3:
match_num = sum(1 for mood in original_mood_set if mood in eval_mood_set)
if match_num >= 3:
correct += 1
else:
if eval_mood == original_mood:
correct += 1
total += 1
CM_top3_score = correct / total if total > 0 else 0
print(f"CM Top-3 Score: {CM_top3_score:.4f} (Correct: {correct}, Total: {total})")
return CM_top3_score
def calculate_all_scores(results):
"""
Calculate all scores and return them as a dictionary.
"""
scores = {
'TBT_score': calculate_TBT_score(results),
'CK_score': calculate_CK_score(results),
'CKD_score': calculate_CKD_score(results),
'CTS_score': calculate_CTS_score(results),
'ECM_score': calculate_ECM_score(results),
'CMO_score': calculate_CMO_score(results),
'CI_score': calculate_CI_score(results),
'CI_top1_score': calculate_CI_top1_score(results),
'CG_score': calculate_CG_score(results),
'CG_top1_score': calculate_CG_top1_score(results),
'CM_score': calculate_CM_score(results),
'CM_top1_score': calculate_CM_top1_score(results),
'CM_top3_score': calculate_CM_top3_score(results)
}
return scores
if __name__ == "__main__":
scores = calculate_all_scores(results)
print("All Scores:")
for score_name, score_value in scores.items():
print(f"{score_name}: {score_value:.4f}")
# Save the results to a JSON file
output_file = f"{generate_path}/results.json"
with open(output_file, 'w') as f:
json.dump(scores, f, indent=4)
print(f"Results saved to {output_file}")

103
SongEval/ebr.py Normal file
View File

@ -0,0 +1,103 @@
import argparse
import glob
import os
import pandas as pd
import muspy
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
def compute_midi_metrics(file_path):
"""计算单个MIDI文件的音乐指标"""
try:
music = muspy.read(file_path)
scale_consistency = muspy.scale_consistency(music)
pitch_entropy = muspy.pitch_entropy(music)
pitch_class_entropy = muspy.pitch_class_entropy(music)
empty_beat_rate = muspy.empty_beat_rate(music)
groove_consistency = muspy.groove_consistency(music, 12)
metrics = {
'scale_consistency': scale_consistency,
'pitch_entropy': pitch_entropy,
'pitch_class_entropy': pitch_class_entropy,
'empty_beat_rate': empty_beat_rate,
'groove_consistency': groove_consistency,
'filename': os.path.basename(file_path)
}
return metrics
except Exception as e:
print(f"处理文件 {os.path.basename(file_path)} 时出错: {str(e)}")
return None
def compute_directory_metrics(directory_path, num_workers=8):
"""计算目录下所有MIDI文件的音乐指标多线程加速"""
midi_files = []
for root, _, files in os.walk(directory_path):
for file in files:
if file.lower().endswith(('.mid', '.midi')):
midi_files.append(os.path.join(root, file))
if not midi_files:
print("目录及子文件夹中未找到MIDI文件")
return None
all_metrics = []
average_metrics = {
'scale_consistency': 0,
'pitch_entropy': 0,
'pitch_class_entropy': 0,
'empty_beat_rate': 0,
'groove_consistency': 0
}
current_num = 0
total_scale_consistency = 0
total_pitch_entropy = 0
total_pitch_class_entropy = 0
total_empty_beat_rate = 0
total_groove_consistency = 0
print(f"正在处理目录: {directory_path}")
print(f"发现 {len(midi_files)} 个MIDI文件:")
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(compute_midi_metrics, midi_file): midi_file for midi_file in midi_files}
for future in tqdm(as_completed(futures), total=len(midi_files), desc="Processing"):
metrics = future.result()
if metrics is not None:
current_num += 1
total_scale_consistency += metrics['scale_consistency']
total_pitch_entropy += metrics['pitch_entropy']
total_pitch_class_entropy += metrics['pitch_class_entropy']
total_empty_beat_rate += metrics['empty_beat_rate']
total_groove_consistency += metrics['groove_consistency']
average_metrics['scale_consistency'] = total_scale_consistency / current_num
average_metrics['pitch_entropy'] = total_pitch_entropy / current_num
average_metrics['pitch_class_entropy'] = total_pitch_class_entropy / current_num
average_metrics['empty_beat_rate'] = total_empty_beat_rate / current_num
average_metrics['groove_consistency'] = total_groove_consistency / current_num
print("current_metrics:", metrics)
all_metrics.append(metrics)
if not all_metrics:
print("所有文件处理失败")
return None
df = pd.DataFrame(all_metrics)
output_csv = os.path.join(directory_path, "midi_metrics_report.csv")
df.to_csv(output_csv, index=False)
avg_metrics = df.mean(numeric_only=True)
return df, avg_metrics
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="计算目录下所有MIDI文件的音乐指标")
parser.add_argument("path", type=str, help="包含MIDI文件的目录路径")
parser.add_argument("--threads", type=int, default=1, help="线程数默认8")
args = parser.parse_args()
if not os.path.isdir(args.path):
print(f"错误: 路径 '{args.path}' 不存在或不是目录")
else:
result, averages = compute_directory_metrics(args.path, num_workers=args.threads)
if result is not None:
print("\n计算完成! 结果已保存到 midi_metrics_report.csv")
print("\n平均指标值:")
print(averages.to_string())
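For reference, the metrics can also be computed programmatically; a minimal sketch (the directory name is a placeholder):
```python
# minimal sketch; "generated_midis" is a placeholder directory of .mid files
from ebr import compute_directory_metrics

out = compute_directory_metrics("generated_midis", num_workers=4)
if out is not None:
    df, averages = out          # per-file table and column means
    print(averages.to_string())
```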

150
SongEval/eval.py Normal file
View File

@ -0,0 +1,150 @@
import glob
import os
import json
import librosa
import numpy as np
import torch
import argparse
from muq import MuQ
from hydra.utils import instantiate
from omegaconf import OmegaConf
from safetensors.torch import load_file
from tqdm import tqdm
class Synthesizer(object):
def __init__(self,
checkpoint_path,
input_path,
output_dir,
use_cpu: bool = False):
self.checkpoint_path = checkpoint_path
self.input_path = input_path
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
self.device = torch.device('cuda') if (torch.cuda.is_available() and (not use_cpu)) else torch.device('cpu')
@torch.no_grad()
def setup(self):
train_config = OmegaConf.load(os.path.join(os.path.dirname(self.checkpoint_path), '../config.yaml'))
model = instantiate(train_config.generator).to(self.device).eval()
state_dict = load_file(self.checkpoint_path, device="cpu")
model.load_state_dict(state_dict, strict=False)
self.model = model
self.muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
self.muq = self.muq.to(self.device).eval()
self.result_dcit = {}
@torch.no_grad()
def synthesis(self):
if os.path.isfile(self.input_path):
if self.input_path.endswith(('.wav', '.mp3')):
lines = []
lines.append(self.input_path)
else:
with open(self.input_path, "r") as f:
lines = [line for line in f]
input_files = [{
"input_path": line.strip(),
} for line in lines]
print(f"input filelst: {self.input_path}")
elif os.path.isdir(self.input_path):
input_files = [{
"input_path": file,
} for file in glob.glob(os.path.join(self.input_path, '*')) if file.lower().endswith(('.wav', '.mp3'))]
else:
raise ValueError(f"input_path {self.input_path} is not a file or directory")
for input in tqdm(input_files):
try:
self.handle(**input)
except Exception as e:
print(e)
continue
# add average
avg_values = {}
for key in self.result_dcit[list(self.result_dcit.keys())[0]].keys():
avg_values[key] = round(np.mean([self.result_dcit[fid][key] for fid in self.result_dcit]), 4)
self.result_dcit['average'] = avg_values
# save result
with open(os.path.join(self.output_dir, "result.json"), "w") as f:
json.dump(self.result_dcit, f, indent=4, ensure_ascii=False)
@torch.no_grad()
def handle(self, input_path):
fid = os.path.basename(input_path).split('.')[0]
if input_path.endswith('.npy'):
input = np.load(input_path)
# check ssl
if len(input.shape) == 3 and input.shape[0] != 1:
print('ssl_shape error', input_path)
return
if np.isnan(input).any():
print('ssl nan', input_path)
return
input = torch.from_numpy(input).to(self.device)
if len(input.shape) == 2:
input = input.unsqueeze(0)
if input_path.endswith(('.wav', '.mp3')):
wav, sr = librosa.load(input_path, sr=24000)
audio = torch.tensor(wav).unsqueeze(0).to(self.device)
output = self.muq(audio, output_hidden_states=True)
input = output["hidden_states"][6]
values = {}
scores_g = self.model(input).squeeze(0)
values['Coherence'] = round(scores_g[0].item(), 4)
values['Musicality'] = round(scores_g[1].item(), 4)
values['Memorability'] = round(scores_g[2].item(), 4)
values['Clarity'] = round(scores_g[3].item(), 4)
values['Naturalness'] = round(scores_g[4].item(), 4)
self.result_dcit[fid] = values
# delete
del input, output, scores_g, values, audio, wav, sr
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input_path",
type=str,
required=True,
help="Input audio: path to a single file, a text file listing audio paths, or a directory of audio files."
)
parser.add_argument(
"-o", "--output_dir",
type=str,
required=True,
help="Output directory for generated results (will be created if it doesn't exist)."
)
parser.add_argument(
"--use_cpu",
type=str,
help="Force CPU mode even if a GPU is available.",
default=False
)
args = parser.parse_args()
ckpt_path = "ckpt/model.safetensors"
synthesizer = Synthesizer(checkpoint_path=ckpt_path,
input_path=args.input_path,
output_dir=args.output_dir,
use_cpu=args.use_cpu)
synthesizer.setup()
synthesizer.synthesis()
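For reference, the scorer can be driven from Python instead of the CLI; a minimal sketch mirroring the `__main__` block above (the input/output paths are placeholders, and `ckpt/model.safetensors` must be present):
```python
# minimal sketch; placeholder paths, mirrors the CLI entry point in eval.py
from eval import Synthesizer

synthesizer = Synthesizer(checkpoint_path="ckpt/model.safetensors",
                          input_path="songs/",        # a file, a file list, or a directory
                          output_dir="scores_out")
synthesizer.setup()
synthesizer.synthesis()  # writes scores_out/result.json with per-file and average scores
```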

View File

@ -0,0 +1,404 @@
import sys
import os
from pathlib import Path
from multiprocessing import Process,set_start_method
import torch
import argparse
from omegaconf import OmegaConf
import json
from collections import defaultdict
from Amadeus.evaluation_utils import (
wandb_style_config_to_omega_config,
prepare_model_and_dataset_from_config,
get_best_ckpt_path_and_config,
Evaluator
)
from transformers import T5Tokenizer, T5EncoderModel
from Amadeus import model_zoo
from Amadeus.symbolic_encoding import data_utils
from Amadeus.model_zoo import AmadeusModel
from Amadeus.symbolic_encoding.data_utils import TuneCompiler
from Amadeus.symbolic_encoding.compile_utils import shift_and_pad
from Amadeus.symbolic_encoding.compile_utils import reverse_shift_and_pad_for_tensor
from Amadeus.symbolic_encoding import decoding_utils
from Amadeus.train_utils import adjust_prediction_order
from data_representation import vocab_utils
from data_representation.vocab_utils import LangTokenVocab
def get_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-wandb_exp_dir",
required=True,
type=str,
help="wandb experiment directory",
)
parser.add_argument(
"-generation_type",
type=str,
choices=('conditioned', 'unconditioned', 'text-conditioned'),
default='unconditioned',
help="generation type",
)
parser.add_argument(
"-sampling_method",
type=str,
choices=('top_p', 'top_k'),
default='top_p',
help="sampling method",
)
parser.add_argument(
"-threshold",
type=float,
default=0.99,
help="threshold",
)
parser.add_argument(
"-temperature",
type=float,
default=1.15,
help="temperature",
)
parser.add_argument(
"-num_samples",
type=int,
default=30,
help="number of samples to generate",
)
parser.add_argument(
"-num_target_measure",
type=int,
default=4,
help="number of target measures for conditioned generation",
)
parser.add_argument(
"-choose_selected_tunes",
action='store_true',
help="generate samples from selected tunes, only for SOD dataset",
)
parser.add_argument(
"-generate_length",
type=int,
default=1024,
help="length of the generated sequence",
)
parser.add_argument(
"-num_processes",
type=int,
default=2,
help="number of processes to use",
)
parser.add_argument(
"-gpu_ids",
type=str,
default="0,5",
help="comma-separated list of GPU IDs to use (e.g., '0,1,2,3')",
)
parser.add_argument(
"-prompt",
type=str,
default="With a rhythm of 100 BPM, this classical piece in 1/4 time signature in the key of Eb major creates a classical mood using String Ensemble, Pizzicato Strings, Tremolo Strings, Trumpet, Timpani.",
help="prompt for generation, only used for conditioned generation",
)
parser.add_argument(
"-prompt_file",
type=str,
default="dataset/midicaps/train.json",
help="file containing prompts for text-conditioned generation",
)
return parser
def load_resources(wandb_exp_dir, device):
"""Load model and dataset resources for a process"""
wandb_dir = Path('wandb')
ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, wandb_exp_dir)
config = OmegaConf.load(config_path)
config = wandb_style_config_to_omega_config(config)
# Load checkpoint to specified device
ckpt = torch.load(ckpt_path, map_location=device)
model, test_set, vocab = prepare_model_and_dataset_from_config(config, metadata_path, vocab_path)
model.load_state_dict(ckpt['model'], strict=False)
model.to(device)
model.eval()
torch.compile(model)
print("total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
# Prepare dataset for prompts
condition_list = [x[1] for x in test_set.data_list]
dataset_for_prompt = []
for i in range(len(condition_list)):
condition = test_set.get_segments_with_tune_idx(condition_list[i], 0)[0]
dataset_for_prompt.append((condition, condition_list[i]))
return config, model, dataset_for_prompt, vocab
def conditioned_worker(process_idx, gpu_id, args, data_slice):
"""Worker process for conditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"cond_{args.num_target_measure}m_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
evaluator = Evaluator(config, model, dataset_for_prompt, vocab, device=device)
# Process assigned data slice
for idx, (tune_in_idx, tune_name) in enumerate(data_slice):
batch_dir = base_path / f"process_{process_idx}_batch_{idx}"
batch_dir.mkdir(parents=True, exist_ok=True)
evaluator.generate_samples_with_prompt(
batch_dir,
args.num_target_measure,
tune_in_idx,
tune_name,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length
)
def generate_samples_unconditioned(config, vocab, model, device,save_dir, num_samples, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = config.nn_params.encoding_scheme
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)
for i in range(num_samples):
generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
decoder(generated_sample, output_path=str(save_dir / f"{uid}_{i}.mid"))
def generate_samples_with_text_prompt(config, vocab, model, device, save_dir, prompt, first_pred_feature, sampling_method, threshold, temperature, generation_length=3072,uid=1):
encoding_scheme = config.nn_params.encoding_scheme
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
encoder = T5EncoderModel.from_pretrained('google/flan-t5-large').to(device)
print(f"Using T5EncoderModel for text prompt: {prompt}")
context = tokenizer(prompt, return_tensors='pt', padding='max_length', truncation=True, max_length=128).to(device)
context = encoder(**context).last_hidden_state
in_beat_resolution_dict = {'Pop1k7': 4, 'Pop909': 4, 'SOD': 12, 'LakhClean': 4}
try:
in_beat_resolution = in_beat_resolution_dict[config.dataset]
except KeyError:
in_beat_resolution = 4 # Default resolution if dataset is not found
midi_decoder_dict = {'remi':'MidiDecoder4REMI', 'cp':'MidiDecoder4CP', 'nb':'MidiDecoder4NB'}
decoder_name = midi_decoder_dict[encoding_scheme]
decoder = getattr(decoding_utils, decoder_name)(vocab=vocab, in_beat_resolution=in_beat_resolution, dataset_name=config.dataset)
generated_sample = model.generate(0, generation_length, condition=None, num_target_measures=None, sampling_method=sampling_method, threshold=threshold, temperature=temperature, context=context)
if encoding_scheme == 'nb':
generated_sample = reverse_shift_and_pad_for_tensor(generated_sample, first_pred_feature)
# Open the jsonl file and count the number of lines to determine the current index
jsonl_path = save_dir / "name2prompt.jsonl"
if jsonl_path.exists():
with open(jsonl_path, 'r') as f:
current_idx = sum(1 for _ in f)
else:
current_idx = 0
name = f"prompt_{current_idx}"
name2prompt_dict = defaultdict(list)
name2prompt_dict[name].append(prompt)
with open(jsonl_path, 'a') as f:
f.write(json.dumps(name2prompt_dict) + '\n')
decoder(generated_sample, output_path=str(save_dir / f"{name}_{uid}.mid"))
def unconditioned_worker(process_idx, gpu_id, args, num_samples):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"uncond_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
# Generate assigned number of samples
batch_dir = base_path
generate_samples_unconditioned(
config,
vocab,
model,
device,
batch_dir,
num_samples,
config.data_params.first_pred_feature,
args.sampling_method,
args.threshold,
args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}"
)
def text_conditioned_worker(process_idx, gpu_id, args, num_samples, data_slice):
"""Worker process for unconditioned generation"""
torch.cuda.set_device(gpu_id)
device = torch.device(f'cuda:{gpu_id}')
# Load resources with proper device
config, model, dataset_for_prompt, vocab = load_resources(args.wandb_exp_dir, device)
# Create output directory with process index
base_path = Path('wandb') / args.wandb_exp_dir / \
f"text_condi_{args.sampling_method}_t{args.threshold}_temp{args.temperature}"
base_path.mkdir(parents=True, exist_ok=True)
# Generate assigned number of samples
batch_dir = base_path
for idx, tune_name in enumerate(data_slice):
print(f"Process {process_idx} generating samples for tune: {tune_name}")
generate_samples_with_text_prompt(
config,
vocab,
model,
device,
batch_dir,
prompt=tune_name,
first_pred_feature=config.data_params.first_pred_feature,
sampling_method=args.sampling_method,
threshold=args.threshold,
temperature=args.temperature,
generation_length=args.generate_length,
uid=f"{process_idx}_{idx}"
)
def main():
# use spawn method for multiprocessing
set_start_method('spawn', force=True)
args = get_argument_parser().parse_args()
gpu_ids = list(map(int, args.gpu_ids.split(',')))
# Validate GPU availability
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
if len(gpu_ids) == 0:
raise ValueError("At least one GPU must be specified")
# Validate process count
if args.num_processes < 1:
raise ValueError("Number of processes must be at least 1")
if len(gpu_ids) < args.num_processes:
print(f"Warning: More processes ({args.num_processes}) than GPUs ({len(gpu_ids)}), some GPUs will be shared")
# Prepare data slices for processes
processes = []
try:
if args.generation_type == 'conditioned':
# Prepare selected tunes
wandb_dir = Path('wandb') / args.wandb_exp_dir
if not wandb_dir.exists():
raise FileNotFoundError(f"Experiment {args.wandb_exp_dir} not found")
# Load test set to get selected tunes (dummy load to get dataset info)
dummy_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_, test_set, _ = prepare_model_and_dataset_from_config(
wandb_dir / "files" / "config.yaml",
wandb_dir / "files" / "metadata.json",
wandb_dir / "files" / "vocab.json"
)
if args.choose_selected_tunes and test_set.dataset == 'SOD':
selected_tunes = ['Requiem_orch', 'magnificat_bwv-243_8_orch',
"Clarinet Concert in A Major: 2nd Movement, Adagio_orch"]
else:
selected_tunes = [name for _, name in test_set.data_list][:args.num_samples]
# Split selected data across processes
selected_data = [d for d in test_set.data_list if d[1] in selected_tunes]
chunk_size = (len(selected_data) + args.num_processes - 1) // args.num_processes
for i in range(args.num_processes):
start_idx = i * chunk_size
end_idx = min((i+1)*chunk_size, len(selected_data))
data_slice = selected_data[start_idx:end_idx]
if not data_slice:
continue
gpu_id = gpu_ids[i % len(gpu_ids)]
p = Process(
target=conditioned_worker,
args=(i, gpu_id, args, data_slice)
)
processes.append(p)
p.start()
elif args.generation_type == 'unconditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
p = Process(
target=unconditioned_worker,
args=(i, gpu_id, args, samples)
)
processes.append(p)
p.start()
elif args.generation_type == 'text-conditioned':
samples_per_proc = args.num_samples // args.num_processes
remainder = args.num_samples % args.num_processes
# Load prompts from file
prompt_name_list = []
with open(args.prompt_file, 'r') as f:
for line in f:
if not line.strip():
continue
prompt_data = json.loads(line.strip())
prompt_text = prompt_data['caption']
if prompt_data['test_set'] is True:
prompt_name_list.append(prompt_text)
print("length of prompt_name_list:", len(prompt_name_list))
if len(prompt_name_list) >= args.num_samples:
print(f"Reached the limit of {args.num_samples} prompts.")
break
for i in range(args.num_processes):
gpu_id = gpu_ids[i % len(gpu_ids)]
samples = samples_per_proc + (1 if i < remainder else 0)
if samples <= 0:
continue
# Split prompt names across processes
start_idx = i * (len(prompt_name_list) // args.num_processes)
end_idx = (i + 1) * (len(prompt_name_list) // args.num_processes)
data_slice = prompt_name_list[start_idx:end_idx]
p = Process(
target=text_conditioned_worker,
args=(i, gpu_id, args, samples, data_slice)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
except Exception as e:
print(f"Error in main process: {str(e)}")
for p in processes:
p.terminate()
raise
if __name__ == "__main__":
main()

68
SongEval/matrics.py Normal file
View File

@ -0,0 +1,68 @@
import argparse
import os
import shutil
import tempfile
import numpy as np
import torch
from audioldm_eval import EvaluationHelper, EvaluationHelperParallel
import torch.multiprocessing as mp
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--generation_path", type=str, required=True, help="Path to generated audio files")
parser.add_argument("--target_path", type=str, required=True, help="Path to reference audio files")
parser.add_argument("--force_paired", action="store_true", help="Force pairing by randomly selecting reference files")
parser.add_argument("--gpu_mode", choices=["single", "multi"], default="single", help="Evaluation mode")
parser.add_argument("--num_gpus", type=int, default=2, help="Number of GPUs for multi-GPU mode")
args = parser.parse_args()
# Handle forced pairing
target_eval_path = args.target_path
temp_dir = None
if args.force_paired:
print(f"Using forced pairing with reference files from {args.target_path}")
temp_dir = tempfile.mkdtemp()
target_eval_path = temp_dir
# Collect generated filenames
gen_files = []
for root, _, files in os.walk(args.generation_path):
for file in files:
if file.endswith(".wav"):
gen_files.append(file)
print(f"Found {len(gen_files)} generated files in {args.generation_path}")
# Collect all reference files
ref_files = []
for root, _, files in os.walk(args.target_path):
for file in files:
if file.endswith(".wav"):
ref_files.append(os.path.join(root, file))
# Select random references matching the count
selected_refs = np.random.choice(ref_files, len(gen_files), replace=False)
print(f"Selected {len(selected_refs)} reference files for evaluation.")
# Copy selected references to temp dir with generated filenames
for gen_file, ref_path in zip(gen_files, selected_refs):
shutil.copy(ref_path, os.path.join(temp_dir, gen_file))
device = torch.device(f"cuda:{0}") if args.gpu_mode == "single" else None
try:
if args.gpu_mode == "single":
print("Running single GPU evaluation...")
evaluator = EvaluationHelper(16000, device)
metrics = evaluator.main(args.generation_path, target_eval_path)
else:
print(f"Running multi-GPU evaluation on {args.num_gpus} GPUs...")
evaluator = EvaluationHelperParallel(16000, args.num_gpus)
metrics = evaluator.main(args.generation_path, target_eval_path)
print("Evaluation completed.")
finally:
# Clean up temporary directory
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
if __name__ == "__main__":
main()

66
SongEval/model.py Normal file
View File

@ -0,0 +1,66 @@
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self,
in_features,
ffd_hidden_size,
num_classes,
attn_layer_num,
):
super(Generator, self).__init__()
self.attn = nn.ModuleList(
[
nn.MultiheadAttention(
embed_dim=in_features,
num_heads=8,
dropout=0.2,
batch_first=True,
)
for _ in range(attn_layer_num)
]
)
self.ffd = nn.Sequential(
nn.Linear(in_features, ffd_hidden_size),
nn.ReLU(),
nn.Linear(ffd_hidden_size, in_features)
)
self.dropout = nn.Dropout(0.2)
self.fc = nn.Linear(in_features * 2, num_classes)
self.proj = nn.Tanh()
def forward(self, ssl_feature, judge_id=None):
'''
ssl_feature: [B, T, D]
output: [B, num_classes]
'''
B, T, D = ssl_feature.shape
ssl_feature = self.ffd(ssl_feature)
tmp_ssl_feature = ssl_feature
for attn in self.attn:
tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)
ssl_feature = self.dropout(torch.concat([torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]], dim=1)) # B, 2D
x = self.fc(ssl_feature) # B, num_classes
x = self.proj(x) * 2.0 + 3 # tanh output rescaled into the (1, 5) score range
return x
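For reference, a minimal sketch of running the scorer head directly with the values from `config.yaml` (the random tensor stands in for MuQ SSL features):
```python
# minimal sketch; dummy input in place of real MuQ hidden states
import torch
from model import Generator

generator = Generator(in_features=1024, ffd_hidden_size=4096, num_classes=5, attn_layer_num=4)
feats = torch.randn(1, 250, 1024)   # [B, T, D]
scores = generator(feats)           # [1, 5]; tanh output rescaled into the (1, 5) range
print(scores)
```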

View File

@ -0,0 +1,4 @@
librosa==0.11.0
torch==2.7.0
muq==0.1.0
hydra-core==1.3.2

3397
SongEval/result.json Normal file

File diff suppressed because it is too large Load Diff