Files
MIDIFoundationModel/SongEval/controlability.py
2025-09-08 14:49:28 +08:00

457 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
# Directory holding the generated MIDI outputs and their metadata files.
# Alternative generation directories (uncomment one to evaluate it instead):
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# Ground-truth rows (JSON Lines); only rows whose 'test_set' flag is True are used below.
test_set_json = "dataset/midicaps/train.json"
# Attributes of the generated files (JSON Lines, one object with a 'name' field per line).
generated_eval_json_path = f"{generate_path}/eval.json"
# Generated file name -> list of prompts (JSON Lines).
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"

# 1. Read the test set and build a caption -> ground-truth entry mapping.
test_set = []
with open(test_set_json, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        test_set.append(json.loads(line))
# .get() so rows without a 'test_set' flag count as non-test rows instead of raising KeyError.
prompt2item = {item['caption']: item for item in test_set if item.get('test_set') is True}
print(f"Number of prompts in test set: {len(prompt2item)}")
# 2. Read name2prompt.jsonl and build a generated-name -> prompt mapping.
#    Each line is a JSON object mapping names to a list of prompts. Only the first prompt
#    is kept, and values that are not non-empty lists are ignored.
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, like the other JSON Lines readers in this script
        obj = json.loads(line)
        name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and v})

# 3. Read eval.json (JSON Lines). Each object is expected to carry a 'name' field.
eval_items = []
with open(generated_eval_json_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        eval_items.append(json.loads(line))
# 4. For every generated name, look up its prompt, make sure the prompt is in the test set,
#    and pair the matching eval.json entry with the ground-truth entry.
results = []

# Normalise eval.json names to the form used as keys in name2prompt, and index the entries
# by that name so each lookup below is a dict access instead of a scan over eval_items.
eval_by_name = {}
for item in eval_items:
    # The name may be a path: keep only the last component.
    base_name = item['name'].split('/')[-1]
    if '_' in base_name:
        # Drop the extension and everything from the second underscore on,
        # e.g. "12_3_extra.mid" -> "12_3". Joining the first two parts (instead of
        # indexing parts[1]) also handles stems that contain no underscore at all.
        base_name = '_'.join(base_name.split('.')[0].split('_')[:2])
    item['name'] = base_name
    # setdefault keeps the first entry when several eval items share a name.
    eval_by_name.setdefault(base_name, item)

for name, prompt in name2prompt.items():
    if prompt not in prompt2item:
        print(f"Prompt not found in test set: {prompt}")
        continue
    eval_entry = eval_by_name.get(name)
    if eval_entry is None:
        print(f"Eval entry not found for name: {name}")
        continue
    results.append({
        'name': name,
        'prompt': prompt,
        'eval_entry': eval_entry,
        'original_entry': prompt2item[prompt],
    })
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")
def calculate_TBT_score(results, lower_tolerance=10, upper_tolerance=15):
    """
    Tempo Bin with Tolerance (TBT).

    Intended meaning: the predicted bpm falls into the ground-truth tempo bin or a
    neighbouring one. This implementation approximates that with a fixed bpm window:
    a prediction is correct when

        original_tempo - lower_tolerance <= eval_tempo <= original_tempo + upper_tolerance

    Args:
        results: list of dicts with 'eval_entry' and 'original_entry' keys.
        lower_tolerance: bpm allowed below the ground-truth tempo (default 10).
        upper_tolerance: bpm allowed above the ground-truth tempo (default 15).

    Returns:
        Fraction of scored pairs that are correct, or 0 if no pair could be scored.
        Pairs missing 'tempo' on either side, or with a None tempo, are skipped.
        A list-valued eval 'tempo' is unwrapped to its first element.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'tempo' not in eval_entry or 'tempo' not in original_entry:
            continue
        eval_tempo = eval_entry['tempo'][0] if isinstance(eval_entry['tempo'], list) else eval_entry['tempo']
        original_tempo = original_entry['tempo']
        if original_tempo is None or eval_tempo is None:
            continue  # no tempo on one side: nothing to score
        if original_tempo - lower_tolerance <= eval_tempo <= original_tempo + upper_tolerance:
            correct += 1
        total += 1
    TBT_score = correct / total if total > 0 else 0
    # Label as TBT: "TB" is the separate, tolerance-free tempo-bin metric.
    print(f"TBT Score: {TBT_score:.4f} (Correct: {correct}, Total: {total})")
    return TBT_score
def calculate_CK_score(results, default_key="C major"):
    """
    Correct Key (CK): the predicted key string equals the ground-truth key string.

    A None key on either side is replaced by ``default_key`` before comparing. With the
    default "C major", such pairs are always scored, and two missing keys count as a match.
    Pass ``default_key=None`` to skip pairs with a missing key instead.

    Args:
        results: list of dicts with 'eval_entry' and 'original_entry' keys.
        default_key: substitute for a missing (None) key, or None to skip such pairs.

    Returns:
        Fraction of scored pairs whose keys match (0 if nothing was scored).
        A list-valued eval 'key' is unwrapped to its first element.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' not in eval_entry or 'key' not in original_entry:
            continue
        eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
        if eval_key is None:
            eval_key = default_key
        original_key = original_entry['key'] if original_entry['key'] is not None else default_key
        if eval_key is None or original_key is None:
            continue  # only reachable when default_key is None
        if eval_key == original_key:
            correct += 1
        total += 1
    CK_score = correct / total if total > 0 else 0
    print(f"CK Score: {CK_score:.4f} (Correct: {correct}, Total: {total})")
    return CK_score
# Pitch class (0-11) of each natural note letter; used to compare keys numerically.
_NATURAL_PITCH_CLASSES = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}
# Semitone offset of each accidental accepted after the note letter.
_ACCIDENTAL_OFFSETS = {'#': 1, 'b': -1}


def _parse_key(key):
    """Parse a key like "F# minor" into (pitch_class, mode); return None if unrecognised."""
    parts = key.split() if isinstance(key, str) else []
    if len(parts) != 2:
        return None
    tonic, mode = parts
    mode = mode.lower()
    if mode not in ('major', 'minor') or tonic[0].upper() not in _NATURAL_PITCH_CLASSES:
        return None
    pitch_class = _NATURAL_PITCH_CLASSES[tonic[0].upper()]
    for symbol in tonic[1:]:
        if symbol not in _ACCIDENTAL_OFFSETS:
            return None
        pitch_class += _ACCIDENTAL_OFFSETS[symbol]
    return pitch_class % 12, mode


def _keys_equivalent(key_a, key_b):
    """Return True for identical, enharmonically equal, or relative major/minor keys."""
    if key_a == key_b:
        return True
    parsed_a, parsed_b = _parse_key(key_a), _parse_key(key_b)
    if parsed_a is None or parsed_b is None:
        return False  # unparseable keys only match on exact string equality
    if parsed_a == parsed_b:
        return True  # same key spelled differently, e.g. C# major / Db major
    (pc_a, mode_a), (pc_b, mode_b) = parsed_a, parsed_b
    if mode_a == mode_b:
        return False
    major_pc, minor_pc = (pc_a, pc_b) if mode_a == 'major' else (pc_b, pc_a)
    # A relative major lies three semitones above its minor tonic (A minor -> C major).
    return (minor_pc + 3) % 12 == major_pc


def calculate_CKD_score(results, default_key="C major"):
    """
    Correct Key with Duplicates (CKD): the predicted key matches the ground-truth key or
    an equivalent key. Equivalent means either the same key spelled enharmonically
    (C# major / Db major) or a major key and its relative minor (C major / A minor).

    Keys are expected as "<tonic> <mode>" strings, e.g. "Bb major". Keys that cannot be
    parsed that way only match on exact string equality.

    Args:
        results: list of dicts with 'eval_entry' and 'original_entry' keys.
        default_key: substitute for a missing (None) key, or None to skip such pairs.

    Returns:
        Fraction of scored pairs with equivalent keys (0 if nothing was scored).
        A list-valued eval 'key' is unwrapped to its first element.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' not in eval_entry or 'key' not in original_entry:
            continue
        eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
        if eval_key is None:
            eval_key = default_key
        original_key = original_entry['key'] if original_entry['key'] is not None else default_key
        if eval_key is None or original_key is None:
            continue  # only reachable when default_key is None
        if _keys_equivalent(eval_key, original_key):
            correct += 1
        total += 1
    CKD_score = correct / total if total > 0 else 0
    print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
    return CKD_score
def _parse_time_signature(time_signature):
    """Return (numerator, denominator) for strings like "4/4", or None if malformed."""
    try:
        numerator, denominator = map(int, str(time_signature).split('/'))
    except ValueError:
        return None
    return numerator, denominator


def calculate_CTS_score(results):
    """
    Correct Time Signature (CTS): the predicted time signature matches the ground truth.

    A prediction is accepted when any of these holds:
    - it equals the ground truth as a string or numerically (e.g. " 4/4" vs "4/4");
    - it has the same denominator and half the ground-truth numerator
      (e.g. predicted 2/4 for a 4/4 ground truth).
    Time signatures that cannot be parsed as "<int>/<int>" are scored as incorrect
    instead of aborting the evaluation.

    Returns:
        Fraction of scored pairs that are correct (0 if nothing was scored). Pairs missing
        'time_signature' on either side, or with a None value, are skipped. A list-valued
        eval 'time_signature' is unwrapped to its first element.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'time_signature' not in eval_entry or 'time_signature' not in original_entry:
            continue
        eval_time_signature = eval_entry['time_signature']
        if isinstance(eval_time_signature, list):
            eval_time_signature = eval_time_signature[0]
        original_time_signature = original_entry['time_signature']
        if original_time_signature is None or eval_time_signature is None:
            continue  # no time signature on one side: nothing to score
        if eval_time_signature == original_time_signature:
            correct += 1
        else:
            eval_parsed = _parse_time_signature(eval_time_signature)
            original_parsed = _parse_time_signature(original_time_signature)
            if eval_parsed is not None and original_parsed is not None:
                (eval_num, eval_den), (orig_num, orig_den) = eval_parsed, original_parsed
                if eval_den == orig_den and (eval_num == orig_num or eval_num * 2 == orig_num):
                    correct += 1
        total += 1
    CTS_score = correct / total if total > 0 else 0
    print(f"CTS Score: {CTS_score:.4f} (Correct: {correct}, Total: {total})")
    return CTS_score
def calculate_ECM_score(results):
    """
    Exact Chord Match (ECM).

    Metric definition: the predicted chord sequence matches the ground truth exactly in
    order, chord root and chord type, with tolerance for missing and excess chord
    instances. NOTE(review): this implementation compares the two chord summaries with
    plain equality; no tolerance for missing/excess chords is applied.

    A list-valued eval 'chord_summary' is unwrapped to its first element. Pairs lacking
    'chord_summary' on either side, or holding None, are not scored.

    Returns the fraction of scored pairs that match exactly (0 when none were scored).
    """
    hits = 0
    scored = 0
    for result in results:
        predicted, reference = result['eval_entry'], result['original_entry']
        if not ('chord_summary' in predicted and 'chord_summary' in reference):
            continue
        predicted_summary = predicted['chord_summary']
        if isinstance(predicted_summary, list):
            predicted_summary = predicted_summary[0]
        reference_summary = reference['chord_summary']
        if predicted_summary is None or reference_summary is None:
            continue
        scored += 1
        if predicted_summary == reference_summary:
            hits += 1
    ECM_score = hits / scored if scored > 0 else 0
    print(f"ECM Score: {ECM_score:.4f} (Correct: {hits}, Total: {scored})")
    return ECM_score
def calculate_CMO_score(results):
    """
    Chord Match in any Order (CMO).

    A pair is correct when every chord in the ground-truth 'chord_summary' appears
    somewhere in the predicted 'chords' sequence. Order and extra predicted chords are
    ignored. NOTE(review): the metric is described as "the portion of the predicted chord
    sequence matching the ground truth", but each pair is scored all-or-nothing here.

    The predicted sequence may hold [chord, value] pairs, e.g. ['C', 0.464399092]; only
    the chord label of each pair is compared.

    Returns:
        Fraction of scored pairs that are correct (0 if nothing was scored). Pairs missing
        the fields, or holding None, are skipped.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'chords' not in eval_entry or 'chord_summary' not in original_entry:
            continue
        eval_chords_seq = eval_entry['chords']
        original_chord_summary = original_entry['chord_summary']
        if original_chord_summary is None or eval_chords_seq is None:
            continue
        if isinstance(eval_chords_seq, list):
            # Strip the per-chord number element by element, so a list mixing bare labels
            # and [chord, value] pairs neither yields unhashable members nor truncates labels.
            eval_chords_seq = [chord[0] if isinstance(chord, list) else chord for chord in eval_chords_seq]
        # Subset test; equal sets are a special case of it, so no separate check is needed.
        if set(original_chord_summary).issubset(eval_chords_seq):
            correct += 1
        total += 1
    CMO_score = correct / total if total > 0 else 0
    print(f"CMO Score: {CMO_score:.4f} (Correct: {correct}, Total: {total})")
    return CMO_score
def calculate_CI_score(results):
    """
    Correct Instrument (CI): every ground-truth instrument appears among the predicted
    instruments.

    The prediction is read from the eval entry's 'mapped_instruments_summary' and the
    ground truth from 'instrument_summary'. If the prediction is a list, the pair is correct
    when the ground-truth instruments form a subset of it. Otherwise the two values must be
    equal.

    Returns:
        Fraction of scored pairs that are correct (0 if nothing was scored). Pairs missing
        either field, or holding None, are skipped.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' not in eval_entry or 'instrument_summary' not in original_entry:
            continue
        # Read the same key whose presence was just checked; 'mapped_instruments' may be absent.
        eval_instrument = eval_entry['mapped_instruments_summary']
        original_instrument = original_entry['instrument_summary']
        if original_instrument is None or eval_instrument is None:
            continue
        if isinstance(eval_instrument, list):
            matched = set(original_instrument).issubset(eval_instrument)
        else:
            matched = eval_instrument == original_instrument
        if matched:
            correct += 1
        total += 1
    CI_score = correct / total if total > 0 else 0
    print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_score
def calculate_CI_top1_score(results):
    """
    Correct Instrument Top-1 (CI_top1): at least one ground-truth instrument appears among
    the predicted instruments.

    The prediction is read from the eval entry's 'mapped_instruments_summary' and the
    ground truth from 'instrument_summary'. If the prediction is a list, the pair is correct
    when the two collections share at least one instrument. Otherwise the values must be
    equal.

    Returns:
        Fraction of scored pairs that are correct (0 if nothing was scored). Pairs missing
        either field, or holding None, are skipped.
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' not in eval_entry or 'instrument_summary' not in original_entry:
            continue
        # Read the same key whose presence was just checked; 'mapped_instruments' may be absent.
        eval_instrument = eval_entry['mapped_instruments_summary']
        original_instrument = original_entry['instrument_summary']
        if original_instrument is None or eval_instrument is None:
            continue
        if isinstance(eval_instrument, list):
            matched = not set(original_instrument).isdisjoint(set(eval_instrument))
        else:
            matched = eval_instrument == original_instrument
        if matched:
            correct += 1
        total += 1
    CI_top1_score = correct / total if total > 0 else 0
    print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_top1_score
def calculate_CG_score(results):
    """
    Correct Genre (CG).

    A list-valued eval 'genre' field is first unwrapped to its first element. If that
    prediction is itself a list, the pair is correct when every ground-truth genre appears
    in it. Otherwise the prediction must equal the ground truth exactly. Pairs lacking
    'genre' on either side, or holding None, are not scored.

    Prints and returns the fraction of correct pairs (0 if none were scored).
    """
    outcomes = []
    for result in results:
        ev, gt = result['eval_entry'], result['original_entry']
        if 'genre' not in ev or 'genre' not in gt:
            continue
        predicted = ev['genre'][0] if isinstance(ev['genre'], list) else ev['genre']
        expected = gt['genre']
        if expected is None or predicted is None:
            continue
        if isinstance(predicted, list):
            predicted_set = set(predicted)
            outcomes.append(set(expected) <= predicted_set)
        else:
            outcomes.append(predicted == expected)
    correct = sum(1 for ok in outcomes if ok)
    total = len(outcomes)
    CG_score = correct / total if total > 0 else 0
    print(f"CG Score: {CG_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_score
def calculate_CG_top1_score(results):
    """
    Correct Genre Top-1 (CG_top1).

    A list-valued eval 'genre' field is first unwrapped to its first element. If that
    prediction is itself a list, the pair is correct when at least one ground-truth genre
    appears in it. Otherwise the prediction must equal the ground truth exactly. Pairs
    lacking 'genre' on either side, or holding None, are not scored.

    Prints and returns the fraction of correct pairs (0 if none were scored).
    """
    correct = total = 0
    for result in results:
        ev = result['eval_entry']
        gt = result['original_entry']
        if not ('genre' in ev and 'genre' in gt):
            continue
        raw = ev['genre']
        predicted = raw[0] if isinstance(raw, list) else raw
        expected = gt['genre']
        if predicted is None or expected is None:
            continue
        if isinstance(predicted, list):
            hit = not set(expected).isdisjoint(set(predicted))
        else:
            hit = predicted == expected
        if hit:
            correct += 1
        total += 1
    CG_top1_score = correct / total if total > 0 else 0
    print(f"CG Top-1 Score: {CG_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CG_top1_score
def calculate_CM_score(results):
    """
    Correct Mood (CM).

    A list-valued eval 'mood' field is unwrapped to its first element. If that prediction
    is itself a list, the pair is correct when it contains every ground-truth mood.
    Otherwise the prediction must equal the ground truth. Pairs missing 'mood' on either
    side, or holding None, are not scored.

    Prints and returns the share of scored pairs that are correct (0 when none were scored).
    """
    correct = total = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'mood' not in predicted_entry or 'mood' not in reference_entry:
            continue
        raw_prediction = predicted_entry['mood']
        prediction = raw_prediction[0] if isinstance(raw_prediction, list) else raw_prediction
        reference = reference_entry['mood']
        if prediction is None or reference is None:
            continue
        total += 1
        if isinstance(prediction, list):
            prediction_set = set(prediction)
            if set(reference).issubset(prediction_set):
                correct += 1
        elif prediction == reference:
            correct += 1
    CM_score = correct / total if total > 0 else 0
    print(f"CM Score: {CM_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_score
def calculate_CM_top1_score(results):
    """
    Correct Mood Top-1 (CM_top1).

    A list-valued eval 'mood' field is unwrapped to its first element. If that prediction
    is itself a list, the pair is correct when at least one ground-truth mood appears in
    it. Otherwise the prediction must equal the ground truth. Pairs missing 'mood' on
    either side, or holding None, are not scored.

    Prints and returns the share of scored pairs that are correct (0 when none were scored).
    """
    correct = 0
    total = 0
    for result in results:
        predicted_entry, reference_entry = result['eval_entry'], result['original_entry']
        if not ('mood' in predicted_entry and 'mood' in reference_entry):
            continue
        prediction = predicted_entry['mood']
        if isinstance(prediction, list):
            prediction = prediction[0]
        reference = reference_entry['mood']
        if reference is None or prediction is None:
            continue
        if isinstance(prediction, list):
            prediction_set = set(prediction)
            hit = any(mood in prediction_set for mood in set(reference))
        else:
            hit = prediction == reference
        correct += 1 if hit else 0
        total += 1
    CM_top1_score = correct / total if total > 0 else 0
    print(f"CM Top-1 Score: {CM_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top1_score
def calculate_CM_top3_score(results):
    """
    Correct Mood Top-3 (CM_top3).

    A list-valued eval 'mood' field is unwrapped to its first element. If that prediction
    is itself a list, the pair is correct when at least min(3, number of distinct
    ground-truth moods) ground-truth moods appear in it. So a ground truth of up to three
    moods must be fully covered, and a larger one needs any three matches. A non-list
    prediction must equal the ground truth. Pairs missing 'mood' on either side, or holding
    None, are not scored.

    Prints and returns the share of scored pairs that are correct (0 when none were scored).
    """
    correct = 0
    total = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'mood' not in predicted_entry or 'mood' not in reference_entry:
            continue
        prediction = predicted_entry['mood'][0] if isinstance(predicted_entry['mood'], list) else predicted_entry['mood']
        reference = reference_entry['mood']
        if reference is None or prediction is None:
            continue
        if isinstance(prediction, list):
            prediction_set = set(prediction)
            reference_set = set(reference)
            required = min(3, len(reference_set))
            if len(reference_set & prediction_set) >= required:
                correct += 1
        elif prediction == reference:
            correct += 1
        total += 1
    CM_top3_score = correct / total if total > 0 else 0
    print(f"CM Top-3 Score: {CM_top3_score:.4f} (Correct: {correct}, Total: {total})")
    return CM_top3_score
def calculate_all_scores(results):
    """
    Run every controllability metric on ``results`` and collect the scores.

    Metrics run in the order listed in ``metric_table``, and each prints its own summary
    line. That order is also the key order of the returned dict.

    Returns:
        dict mapping metric name (e.g. 'TBT_score') to its score.
    """
    metric_table = (
        ('TBT_score', calculate_TBT_score),
        ('CK_score', calculate_CK_score),
        ('CKD_score', calculate_CKD_score),
        ('CTS_score', calculate_CTS_score),
        ('ECM_score', calculate_ECM_score),
        ('CMO_score', calculate_CMO_score),
        ('CI_score', calculate_CI_score),
        ('CI_top1_score', calculate_CI_top1_score),
        ('CG_score', calculate_CG_score),
        ('CG_top1_score', calculate_CG_top1_score),
        ('CM_score', calculate_CM_score),
        ('CM_top1_score', calculate_CM_top1_score),
        ('CM_top3_score', calculate_CM_top3_score),
    )
    return {label: metric(results) for label, metric in metric_table}
if __name__ == "__main__":
    # Each metric also prints its own detailed line while being computed.
    scores = calculate_all_scores(results)
    print("All Scores:")
    for metric_name, metric_value in scores.items():
        print(f"{metric_name}: {metric_value:.4f}")
    # Write the metric dictionary as pretty-printed JSON into the generation directory.
    output_file = f"{generate_path}/results.json"
    serialized = json.dumps(scores, indent=4)
    with open(output_file, 'w') as out:
        out.write(serialized)
    print(f"Results saved to {output_file}")