457 lines
20 KiB
Python
457 lines
20 KiB
Python
import json

# --- Run configuration ------------------------------------------------------
# Directory of the generation run being scored. The script reads
# ``eval.json`` and ``name2prompt.jsonl`` from it and, when run as a script,
# writes ``results.json`` back into it. Uncomment one of the alternative
# lines to score a different run.
generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'

# Ground-truth annotation file, read as one JSON object per line.
test_set_json = "dataset/midicaps/train.json"

# Inputs expected inside the run directory.
generated_eval_json_path = generate_path + "/eval.json"
generated_name2prompt_jsonl_path = generate_path + "/name2prompt.jsonl"
# 1. Read the annotation file (JSON lines) and map each test-set caption to
#    its full record.
with open(test_set_json, 'r', encoding='utf-8') as f:
    # Blank lines are skipped; json.loads ignores surrounding whitespace.
    test_set = [json.loads(line) for line in f if line.strip()]

# Keep only records whose ``test_set`` flag is exactly ``True``. ``.get``
# treats a missing flag as "not in the test set" instead of raising KeyError.
# If two test records share a caption, the later one wins.
prompt2item = {item['caption']: item for item in test_set if item.get('test_set') is True}

print(f"Number of prompts in test set: {len(prompt2item)}")
# 2. Read name2prompt.jsonl and build a generated-name -> prompt map. Each
#    line is a JSON object; for every value that is a non-empty list, its
#    first element is taken as that name's prompt.
name2prompt = {}
with open(generated_name2prompt_jsonl_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines like the other loaders
        obj = json.loads(line)
        name2prompt.update({k: v[0] for k, v in obj.items() if isinstance(v, list) and v})
# 3. Read eval.json (JSON lines, one record per line), skipping blank lines.
#    An explicit encoding keeps the result independent of the platform locale.
with open(generated_eval_json_path, 'r', encoding='utf-8') as f:
    eval_items = [json.loads(line) for line in f if line.strip()]
# 4. For every generated name, look up its prompt, make sure the prompt is in
#    the test set, then find the matching eval.json record.
results = []

# Normalize eval record names so they can be compared with the keys of
# name2prompt: keep only the last path component and, when the part before
# the first '.' contains an underscore, keep just its first two
# underscore-separated fields (e.g. "dir/abc_12_3.mid" -> "abc_12").
# The underscore test is done on that same stem, so a name whose only
# underscore sits after the first '.' is left as-is instead of raising
# IndexError.
for item in eval_items:
    base_name = item['name'].split('/')[-1]
    stem = base_name.split('.')[0]
    if '_' in stem:
        first, second = stem.split('_')[:2]
        base_name = f"{first}_{second}"
    item['name'] = base_name
# Index eval records by name once. setdefault keeps the FIRST record per
# name, matching what a linear ``next(...)`` scan would return, but avoids
# rescanning eval_items for every generated name (O(N*M) -> O(N+M)).
eval_by_name = {}
for item in eval_items:
    eval_by_name.setdefault(item.get('name'), item)

for name, prompt in name2prompt.items():
    if prompt not in prompt2item:
        print(f"Prompt not found in test set: {prompt}")
        continue

    eval_entry = eval_by_name.get(name)
    if eval_entry is None:
        print(f"Eval entry not found for name: {name}")
        continue

    results.append({
        'name': name,
        'prompt': prompt,
        'eval_entry': eval_entry,
        'original_entry': prompt2item[prompt],  # ground-truth record
    })

print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0] if results else 'No results'}")
def calculate_TBT_score(results):
    """Tempo Bin with Tolerance (TBT).

    Nominally the metric checks that the predicted bpm falls into the
    ground-truth tempo bin or a neighbouring one. This implementation uses
    no bins: a sample is correct when
    ``ground_truth - 10 <= predicted <= ground_truth + 15`` (bpm).

    Args:
        results: list of dicts with ``eval_entry`` and ``original_entry``.
            When ``eval_entry['tempo']`` is a list, its first element is the
            predicted tempo.

    Returns:
        Fraction of scored samples that are correct (0 when none are scored).
        Samples lacking ``tempo`` on either side, or holding ``None``, are not
        scored.
    """
    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'tempo' not in predicted_entry or 'tempo' not in reference_entry:
            continue
        raw_tempo = predicted_entry['tempo']
        predicted_bpm = raw_tempo[0] if isinstance(raw_tempo, list) else raw_tempo
        reference_bpm = reference_entry['tempo']
        if reference_bpm is None or predicted_bpm is None:
            continue  # nothing to compare against
        if reference_bpm - 10 <= predicted_bpm <= reference_bpm + 15:
            hits += 1
        scored += 1
    tb_score = hits / scored if scored > 0 else 0
    print(f"TB Score: {tb_score:.4f} (Correct: {hits}, Total: {scored})")
    return tb_score
def calculate_CK_score(results):
    """Correct Key (CK): the predicted key string equals the ground-truth key.

    Results lacking a ``key`` field on either side are ignored. When
    ``eval_entry['key']`` is a list, its first element is the prediction.
    A ``None`` key on either side is replaced by ``"C major"`` before the
    comparison, so such samples are still counted.

    Returns:
        Fraction of counted samples with an exact key match (0 if none).
    """
    fallback_key = "C major"
    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'key' not in predicted_entry or 'key' not in reference_entry:
            continue
        predicted_key = predicted_entry['key']
        if isinstance(predicted_key, list):
            predicted_key = predicted_key[0]
        if predicted_key is None:
            predicted_key = fallback_key
        reference_key = reference_entry['key']
        if reference_key is None:
            reference_key = fallback_key
        # Both keys are non-None after the fallback, so every result that
        # reaches this point is scored.
        if predicted_key == reference_key:
            hits += 1
        scored += 1
    ck_score = hits / scored if scored > 0 else 0
    print(f"CK Score: {ck_score:.4f} (Correct: {hits}, Total: {scored})")
    return ck_score
_NATURAL_PITCH_CLASSES = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}


def _relative_major_pitch_class(key):
    """Return the pitch class (0-11) of the major key sharing *key*'s signature.

    ``key`` is parsed as ``"<tonic> <mode>"``, e.g. ``"C major"``,
    ``"F# minor"`` or ``"Bb minor"``. A minor key maps to its relative major
    (tonic + 3 semitones), so ``"A minor"`` and ``"C major"`` both give 0, and
    enharmonic spellings (``"F#"`` / ``"Gb"``) coincide. Returns ``None`` when
    the string cannot be parsed.
    """
    if not isinstance(key, str):
        return None
    parts = key.split()
    if len(parts) != 2:
        return None
    tonic, mode = parts[0], parts[1].lower()
    letter = tonic[0].upper()
    if letter not in _NATURAL_PITCH_CLASSES or mode not in ('major', 'minor'):
        return None
    pitch_class = _NATURAL_PITCH_CLASSES[letter]
    for accidental in tonic[1:]:
        if accidental in ('#', '\u266f'):
            pitch_class += 1
        elif accidental in ('b', '\u266d'):
            pitch_class -= 1
        else:
            return None
    if mode == 'minor':
        pitch_class += 3  # relative major lies a minor third above
    return pitch_class % 12


def calculate_CKD_score(results):
    """Correct Key with Duplicates (CKD).

    A prediction is correct when it equals the ground-truth key string or is
    an equivalent key, i.e. a major key and its relative minor (``"C major"``
    vs ``"A minor"``), including enharmonic spellings. Parallel keys
    (``"C major"`` vs ``"C minor"``) are NOT equivalent.

    Results lacking ``key`` on either side are ignored. When
    ``eval_entry['key']`` is a list, its first element is the prediction.
    A ``None`` key on either side is replaced by ``"C major"``.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'key' not in eval_entry or 'key' not in original_entry:
            continue
        eval_key = eval_entry['key'][0] if isinstance(eval_entry['key'], list) else eval_entry['key']
        if eval_key is None:
            eval_key = "C major"  # default: C major
        original_key = original_entry['key'] if original_entry['key'] is not None else "C major"

        if eval_key == original_key:
            correct += 1
        else:
            # Compare the key signatures: relative major/minor pairs share one.
            eval_pc = _relative_major_pitch_class(eval_key)
            if eval_pc is not None and eval_pc == _relative_major_pitch_class(original_key):
                correct += 1
        total += 1
    CKD_score = correct / total if total > 0 else 0
    print(f"CKD Score: {CKD_score:.4f} (Correct: {correct}, Total: {total})")
    return CKD_score
def calculate_CTS_score(results):
    """Correct Time Signature (CTS).

    Results lacking ``time_signature`` on either side, or holding ``None``,
    are ignored. When ``eval_entry['time_signature']`` is a list, its first
    element is the prediction. A prediction counts as correct when:

    * it is the same string as the ground truth, or
    * both parse as ``"N/D"`` integers with the same denominator, and the
      predicted numerator equals the ground-truth numerator or half of it
      (so ``"2/4"`` is accepted for ``"4/4"``).

    NOTE(review): ``"2/2"`` vs ``"4/4"`` is *not* accepted by this rule;
    confirm whether that equivalence was intended.

    Differing signatures that do not parse as ``"N/D"`` make this raise
    (e.g. ``ValueError``).

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    def same_or_halved(predicted_sig, reference_sig):
        predicted_num, predicted_den = map(int, predicted_sig.split('/'))
        reference_num, reference_den = map(int, reference_sig.split('/'))
        return predicted_den == reference_den and (
            predicted_num == reference_num or predicted_num * 2 == reference_num
        )

    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'time_signature' not in predicted_entry or 'time_signature' not in reference_entry:
            continue
        predicted = predicted_entry['time_signature']
        if isinstance(predicted, list):
            predicted = predicted[0]
        expected = reference_entry['time_signature']
        if expected is None or predicted is None:
            continue  # no time signature to compare
        # Parsing only happens when the strings differ (short-circuit `or`).
        if predicted == expected or same_or_halved(predicted, expected):
            hits += 1
        scored += 1
    cts_score = hits / scored if scored > 0 else 0
    print(f"CTS Score: {cts_score:.4f} (Correct: {hits}, Total: {scored})")
    return cts_score
def calculate_ECM_score(results):
    """Exact Chord Match (ECM).

    Nominal definition: the predicted chord sequence matches the ground truth
    in order, chord root and chord type, with tolerance for missing and excess
    chord instances. NOTE(review): this implementation applies no such
    tolerance -- the predicted summary must be exactly equal to
    ``original_entry['chord_summary']``.

    When ``eval_entry['chord_summary']`` is a list, its first element is the
    predicted summary (presumably a ``[chords, count]`` pair -- confirm).
    Results lacking ``chord_summary`` on either side, or holding ``None``,
    are ignored.

    Returns:
        Fraction of counted samples with an exact match (0 if none).
    """
    matched = 0
    counted = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if not ('chord_summary' in predicted_entry and 'chord_summary' in reference_entry):
            continue
        predicted_summary = predicted_entry['chord_summary']
        if isinstance(predicted_summary, list):
            predicted_summary = predicted_summary[0]
        reference_summary = reference_entry['chord_summary']
        if reference_summary is None or predicted_summary is None:
            continue
        if predicted_summary == reference_summary:
            matched += 1
        counted += 1
    ecm_score = matched / counted if counted > 0 else 0
    print(f"ECM Score: {ecm_score:.4f} (Correct: {matched}, Total: {counted})")
    return ecm_score
def calculate_CMO_score(results):
    """Chord Match in any Order (CMO).

    A sample is correct when every chord in ``original_entry['chord_summary']``
    occurs somewhere in the predicted chord sequence ``eval_entry['chords']``,
    regardless of order. The score is binary per sample, not a partial
    fraction.

    ``eval_entry['chords']`` may be a list of ``[chord, value]`` pairs
    (e.g. ``[['C', 0.46], ['G', 2.88]]``); the second field is dropped. Results
    lacking ``chords`` / ``chord_summary``, or holding ``None``, are ignored.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'chords' not in predicted_entry or 'chord_summary' not in reference_entry:
            continue
        timeline = predicted_entry['chords']
        # Keep only the chord label of each [chord, value] pair.
        if isinstance(timeline, list) and timeline and isinstance(timeline[0], list):
            timeline = [pair[0] for pair in timeline]
        reference = reference_entry['chord_summary']
        if reference is None or timeline is None:
            continue
        # Subset test; set equality would imply it, so no separate check.
        if set(timeline) >= set(reference):
            hits += 1
        scored += 1
    cmo_score = hits / scored if scored > 0 else 0
    print(f"CMO Score: {cmo_score:.4f} (Correct: {hits}, Total: {scored})")
    return cmo_score
def calculate_CI_score(results):
    """Correct Instrument (CI): every ground-truth instrument is predicted.

    Only results whose ``eval_entry`` has ``mapped_instruments_summary`` and
    whose ``original_entry`` has ``instrument_summary`` are considered.

    The prediction is ``eval_entry['mapped_instruments_summary']``. However,
    when ``eval_entry`` also carries a ``mapped_instruments`` value that is not
    a list, that value is used instead. Samples whose prediction or ground
    truth is ``None`` are skipped. A list prediction is correct when it
    contains every ground-truth instrument; any other prediction must equal
    the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' not in eval_entry or 'instrument_summary' not in original_entry:
            continue
        # ``.get`` with a list default: a record without ``mapped_instruments``
        # uses the summary field checked above instead of raising KeyError.
        mapped = eval_entry.get('mapped_instruments', [])
        eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(mapped, list) else mapped
        original_instrument = original_entry['instrument_summary']
        if original_instrument is None or eval_instrument is None:
            continue
        if isinstance(eval_instrument, list):
            if set(original_instrument).issubset(eval_instrument):
                correct += 1
        elif eval_instrument == original_instrument:
            correct += 1
        total += 1
    CI_score = correct / total if total > 0 else 0
    print(f"CI Score: {CI_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_score
def calculate_CI_top1_score(results):
    """Correct Instrument Top-1 (CI_top1): at least one ground-truth
    instrument appears among the predicted instruments.

    Only results whose ``eval_entry`` has ``mapped_instruments_summary`` and
    whose ``original_entry`` has ``instrument_summary`` are considered. The
    prediction is ``eval_entry['mapped_instruments_summary']``, unless
    ``eval_entry`` carries a non-list ``mapped_instruments`` value, which is
    then used instead. Samples whose prediction or ground truth is ``None``
    are skipped. A non-list prediction must equal the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    correct = 0
    total = 0
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' not in eval_entry or 'instrument_summary' not in original_entry:
            continue
        # ``.get`` with a list default: a record without ``mapped_instruments``
        # uses the summary field checked above instead of raising KeyError.
        mapped = eval_entry.get('mapped_instruments', [])
        eval_instrument = eval_entry['mapped_instruments_summary'] if isinstance(mapped, list) else mapped
        original_instrument = original_entry['instrument_summary']
        if original_instrument is None or eval_instrument is None:
            continue
        if isinstance(eval_instrument, list):
            # Any overlap between ground truth and prediction counts.
            if not set(original_instrument).isdisjoint(eval_instrument):
                correct += 1
        elif eval_instrument == original_instrument:
            correct += 1
        total += 1
    CI_top1_score = correct / total if total > 0 else 0
    print(f"CI Top-1 Score: {CI_top1_score:.4f} (Correct: {correct}, Total: {total})")
    return CI_top1_score
def calculate_CG_score(results):
    """Correct Genre (CG): every ground-truth genre is among the predictions.

    For each result with a ``genre`` field on both sides:

    * the prediction is ``eval_entry['genre'][0]`` when that field is a list
      (presumably a ``[labels, probabilities]`` pair -- confirm), otherwise
      the field itself;
    * samples whose prediction or ground truth is ``None`` are skipped;
    * a list prediction is correct when it contains all ground-truth genres;
      any other prediction must equal the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'genre' not in predicted_entry or 'genre' not in reference_entry:
            continue
        raw_prediction = predicted_entry['genre']
        predicted = raw_prediction[0] if isinstance(raw_prediction, list) else raw_prediction
        expected = reference_entry['genre']
        if expected is None or predicted is None:
            continue
        if isinstance(predicted, list):
            is_match = set(predicted) >= set(expected)
        else:
            is_match = predicted == expected
        if is_match:
            hits += 1
        scored += 1
    cg_score = hits / scored if scored > 0 else 0
    print(f"CG Score: {cg_score:.4f} (Correct: {hits}, Total: {scored})")
    return cg_score
def calculate_CG_top1_score(results):
    """Correct Genre Top-1 (CG_top1): at least one ground-truth genre is
    among the predicted genres.

    The prediction is ``eval_entry['genre'][0]`` when that field is a list
    (presumably a ``[labels, probabilities]`` pair -- confirm), otherwise the
    field itself. Results lacking ``genre`` on either side, or holding
    ``None``, are ignored. A non-list prediction must equal the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    hits = scored = 0
    for result in results:
        predicted_entry, reference_entry = result['eval_entry'], result['original_entry']
        if not ('genre' in predicted_entry and 'genre' in reference_entry):
            continue
        predicted = predicted_entry['genre']
        if isinstance(predicted, list):
            predicted = predicted[0]
        expected = reference_entry['genre']
        if expected is None or predicted is None:
            continue
        if isinstance(predicted, list):
            predicted_set = set(predicted)
            is_match = any(genre in predicted_set for genre in set(expected))
        else:
            is_match = predicted == expected
        if is_match:
            hits += 1
        scored += 1
    cg_top1_score = hits / scored if scored > 0 else 0
    print(f"CG Top-1 Score: {cg_top1_score:.4f} (Correct: {hits}, Total: {scored})")
    return cg_top1_score
def calculate_CM_score(results):
    """Correct Mood (CM): every ground-truth mood is among the predictions.

    The prediction is ``eval_entry['mood'][0]`` when that field is a list
    (presumably a ``[labels, probabilities]`` pair -- confirm), otherwise the
    field itself. Results lacking ``mood`` on either side, or holding
    ``None``, are ignored. A list prediction is correct when it contains all
    ground-truth moods; any other prediction must equal the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    matched = 0
    counted = 0
    for result in results:
        prediction_record = result['eval_entry']
        truth_record = result['original_entry']
        if 'mood' not in prediction_record or 'mood' not in truth_record:
            continue
        predicted_moods = prediction_record['mood']
        if isinstance(predicted_moods, list):
            predicted_moods = predicted_moods[0]
        true_moods = truth_record['mood']
        if true_moods is None or predicted_moods is None:
            continue
        if isinstance(predicted_moods, list):
            is_match = set(predicted_moods) >= set(true_moods)
        else:
            is_match = predicted_moods == true_moods
        if is_match:
            matched += 1
        counted += 1
    cm_score = matched / counted if counted > 0 else 0
    print(f"CM Score: {cm_score:.4f} (Correct: {matched}, Total: {counted})")
    return cm_score
def calculate_CM_top1_score(results):
    """Correct Mood Top-1 (CM_top1): at least one ground-truth mood is among
    the predicted moods.

    The prediction is ``eval_entry['mood'][0]`` when that field is a list
    (presumably a ``[labels, probabilities]`` pair -- confirm), otherwise the
    field itself. Results lacking ``mood`` on either side, or holding
    ``None``, are ignored. A non-list prediction must equal the ground truth.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    matched = 0
    counted = 0
    for result in results:
        prediction_record = result['eval_entry']
        truth_record = result['original_entry']
        if 'mood' not in prediction_record or 'mood' not in truth_record:
            continue
        raw = prediction_record['mood']
        predicted_moods = raw[0] if isinstance(raw, list) else raw
        true_moods = truth_record['mood']
        if true_moods is None or predicted_moods is None:
            continue
        if isinstance(predicted_moods, list):
            # Any shared mood is enough.
            is_match = not set(predicted_moods).isdisjoint(set(true_moods))
        else:
            is_match = predicted_moods == true_moods
        if is_match:
            matched += 1
        counted += 1
    cm_top1_score = matched / counted if counted > 0 else 0
    print(f"CM Top-1 Score: {cm_top1_score:.4f} (Correct: {matched}, Total: {counted})")
    return cm_top1_score
def calculate_CM_top3_score(results):
    """Correct Mood Top-3 (CM_top3).

    With a list prediction, a sample is correct when at least
    ``min(3, number of distinct ground-truth moods)`` ground-truth moods are
    predicted. So up to three ground-truth moods must all be present, and
    beyond that any three suffice. A non-list prediction must equal the
    ground truth.

    The prediction is ``eval_entry['mood'][0]`` when that field is a list
    (presumably a ``[labels, probabilities]`` pair -- confirm), otherwise the
    field itself. Results lacking ``mood`` on either side, or holding
    ``None``, are ignored.

    Returns:
        Fraction of counted samples that are correct (0 if none).
    """
    hits = 0
    scored = 0
    for result in results:
        predicted_entry = result['eval_entry']
        reference_entry = result['original_entry']
        if 'mood' not in predicted_entry or 'mood' not in reference_entry:
            continue
        predicted = predicted_entry['mood']
        if isinstance(predicted, list):
            predicted = predicted[0]
        expected = reference_entry['mood']
        if expected is None or predicted is None:
            continue
        if isinstance(predicted, list):
            predicted_set = set(predicted)
            expected_set = set(expected)
            overlap = len(expected_set & predicted_set)
            is_match = overlap >= min(3, len(expected_set))
        else:
            is_match = predicted == expected
        if is_match:
            hits += 1
        scored += 1
    cm_top3_score = hits / scored if scored > 0 else 0
    print(f"CM Top-3 Score: {cm_top3_score:.4f} (Correct: {hits}, Total: {scored})")
    return cm_top3_score
def calculate_all_scores(results):
    """Run every metric on *results* and return ``{score_name: value}``.

    Metrics run in the order listed below (which is also the key order of the
    returned dict); each one prints its own summary line.
    """
    metrics = (
        ('TBT_score', calculate_TBT_score),
        ('CK_score', calculate_CK_score),
        ('CKD_score', calculate_CKD_score),
        ('CTS_score', calculate_CTS_score),
        ('ECM_score', calculate_ECM_score),
        ('CMO_score', calculate_CMO_score),
        ('CI_score', calculate_CI_score),
        ('CI_top1_score', calculate_CI_top1_score),
        ('CG_score', calculate_CG_score),
        ('CG_top1_score', calculate_CG_top1_score),
        ('CM_score', calculate_CM_score),
        ('CM_top1_score', calculate_CM_top1_score),
        ('CM_top3_score', calculate_CM_top3_score),
    )
    return {label: metric(results) for label, metric in metrics}
if __name__ == "__main__":
    scores = calculate_all_scores(results)

    print("All Scores:")
    for metric_name, metric_value in scores.items():
        print(f"{metric_name}: {metric_value:.4f}")

    # Persist the score dict inside the run directory.
    output_file = generate_path + "/results.json"
    with open(output_file, 'w') as handle:
        json.dump(scores, handle, indent=4)
    print(f"Results saved to {output_file}")