"""Score generated MIDI files against the annotations of their text prompts.

The script joins three JSON-lines files:

* ``test_set_json`` -- annotated captions; only records whose ``test_set``
  field is ``True`` are used.
* ``name2prompt.jsonl`` -- maps a generated file name to a list whose first
  element is the prompt it was generated from.
* ``eval.json`` -- one record of extracted attributes per generated file.

Each ``calculate_*_score`` function prints and returns the fraction of matched
results that satisfy one metric.  Running the module as a script computes all
metrics and writes them to ``<generate_path>/results.json``.
"""
import json

generate_path = 'Text2midi/muzic/musecoco/2-attribute2music_model/generation/0505/linear_mask-1billion-attribute2music/infer_test/topk15-t1.0-ngram0/all_midis'
# generate_path = 'Text2midi/t2m-inferalign/text2midi_infer_output'
# generate_path = 'wandb/no-disp-no-ciem/text_condi_top_p_t0.99_temp1.25'
test_set_json = "dataset/midicaps/train.json"
generated_eval_json_path = f"{generate_path}/eval.json"
generated_name2prompt_jsonl_path = f"{generate_path}/name2prompt.jsonl"

# Key assumed when either side reports no key at all (kept from the original
# scoring rules of CK / CKD).
DEFAULT_KEY = "C major"

# Names that used to be plain module globals built at import time.  They are
# now produced on first access (see ``__getattr__``) so that importing this
# module performs no file I/O.
_LEGACY_GLOBALS = ('test_set', 'prompt2item', 'name2prompt', 'eval_items', 'results')
_state_cache = None

# Semitone offset of each natural note name, used to compare keys.
_PITCH_CLASS = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}


# ---------------------------------------------------------------------------
# Loading and matching
# ---------------------------------------------------------------------------

def load_jsonl(path):
    """Return the JSON values stored one per line in *path*.

    Blank lines are skipped.  The file is read as UTF-8.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def load_name2prompt(path):
    """Return a ``{name: prompt}`` dict read from the JSON-lines file *path*.

    Every line holds a dict; for each entry whose value is a non-empty list,
    the first list element is taken as the prompt.  Other entries are ignored.
    """
    name2prompt = {}
    for obj in load_jsonl(path):
        name2prompt.update(
            {k: v[0] for k, v in obj.items() if isinstance(v, list) and len(v) > 0}
        )
    return name2prompt


def normalize_eval_name(name):
    """Reduce an eval-record path to the short name used in ``name2prompt``.

    The directory part is dropped.  If the text before the first ``.`` has at
    least two ``_``-separated fields, only the first two are kept
    (``'a/b/12_3_x.mid'`` -> ``'12_3'``); otherwise the base name is returned
    unchanged, extension included.
    """
    base = name.split('/')[-1]
    stem_fields = base.split('.')[0].split('_')
    if len(stem_fields) >= 2:
        base = stem_fields[0] + '_' + stem_fields[1]
    return base


def build_results(prompt2item, name2prompt, eval_items):
    """Join generated files with their eval record and ground-truth record.

    Args:
        prompt2item: ``{caption: ground-truth record}`` for the test set.
        name2prompt: ``{generated name: prompt}``.
        eval_items: list of eval records; each record's ``'name'`` is
            normalised **in place** with :func:`normalize_eval_name`.

    Returns:
        A list of dicts with keys ``'name'``, ``'prompt'``, ``'eval_entry'``
        and ``'original_entry'``.  Names whose prompt is not in *prompt2item*
        or that have no eval record are reported on stdout and skipped.
    """
    # Index once instead of scanning eval_items for every name.  setdefault
    # keeps the first record for a name, matching a front-to-back scan.
    eval_by_name = {}
    for item in eval_items:
        item['name'] = normalize_eval_name(item['name'])
        eval_by_name.setdefault(item['name'], item)

    matched = []
    for name, prompt in name2prompt.items():
        if prompt not in prompt2item:
            print(f"Prompt not found in test set: {prompt}")
            continue
        eval_entry = eval_by_name.get(name)
        if eval_entry is None:
            print(f"Eval entry not found for name: {name}")
            continue
        matched.append({
            'name': name,
            'prompt': prompt,
            'eval_entry': eval_entry,
            'original_entry': prompt2item[prompt],
        })
    return matched


def load_state():
    """Read the three input files and build the matched results (cached).

    Returns:
        A dict with keys ``'test_set'``, ``'prompt2item'``, ``'name2prompt'``,
        ``'eval_items'`` and ``'results'``.  The files are read only on the
        first call.
    """
    global _state_cache
    if _state_cache is None:
        # 1. Test set: map each test caption to its annotated record.
        test_set = load_jsonl(test_set_json)
        prompt2item = {
            item['caption']: item for item in test_set if item.get('test_set') is True
        }
        print(f"Number of prompts in test set: {len(prompt2item)}")

        # 2. Generated name -> prompt.
        name2prompt = load_name2prompt(generated_name2prompt_jsonl_path)

        # 3. Extracted attributes of the generated files.
        eval_items = load_jsonl(generated_eval_json_path)

        # 4. Join the three sources.
        results = build_results(prompt2item, name2prompt, eval_items)
        print(f"Number of results: {len(results)}")
        print(f"Sample result: {results[0] if results else 'No results'}")

        _state_cache = {
            'test_set': test_set,
            'prompt2item': prompt2item,
            'name2prompt': name2prompt,
            'eval_items': eval_items,
            'results': results,
        }
    return _state_cache


def __getattr__(attr):
    """Lazily provide the former import-time globals (PEP 562)."""
    if attr in _LEGACY_GLOBALS:
        return load_state()[attr]
    raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")


# ---------------------------------------------------------------------------
# Shared scoring helpers
# ---------------------------------------------------------------------------

def _first_if_list(value):
    """Return ``value[0]`` for a list (``None`` if it is empty), else *value*."""
    if isinstance(value, list):
        return value[0] if value else None
    return value


def _iter_pairs(results, eval_field, original_field, extract=_first_if_list):
    """Yield ``(prediction, ground truth)`` for results that carry both fields.

    *extract* is applied to the raw eval value; the ground-truth value is
    yielded as stored.
    """
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if eval_field in eval_entry and original_field in original_entry:
            yield extract(eval_entry[eval_field]), original_entry[original_field]


def _accuracy(label, pairs, is_match):
    """Print and return the share of *pairs* for which ``is_match`` holds.

    Pairs in which either side is ``None`` are skipped and do not count
    towards the total.  Returns ``0`` when nothing was counted.
    """
    correct = 0
    total = 0
    for predicted, truth in pairs:
        if predicted is None or truth is None:
            continue
        if is_match(predicted, truth):
            correct += 1
        total += 1
    score = correct / total if total > 0 else 0
    print(f"{label} Score: {score:.4f} (Correct: {correct}, Total: {total})")
    return score


def _key_pairs(results):
    """Yield key pairs with a missing key on either side replaced by the default."""
    for predicted, truth in _iter_pairs(results, 'key', 'key'):
        yield (
            predicted if predicted is not None else DEFAULT_KEY,
            truth if truth is not None else DEFAULT_KEY,
        )


def _parse_key(key):
    """Parse ``'<tonic> <mode>'`` into ``(pitch class 0-11, 'major'|'minor')``.

    ``#``/``♯`` raise and ``b``/``♭``/``-`` lower the tonic by a semitone.
    Returns ``None`` for anything that does not fit this pattern.
    """
    fields = str(key).split()
    if len(fields) != 2:
        return None
    tonic, mode = fields[0], fields[1].lower()
    letter = tonic[0].upper()
    if letter not in _PITCH_CLASS or mode not in ('major', 'minor'):
        return None
    pitch = _PITCH_CLASS[letter]
    for accidental in tonic[1:]:
        if accidental in '#♯':
            pitch += 1
        elif accidental in 'b♭-':
            pitch -= 1
        else:
            return None
    return pitch % 12, mode


def _keys_equivalent(predicted, truth):
    """Return True for identical keys, enharmonic spellings, or relative keys."""
    if predicted == truth:
        return True
    parsed_predicted = _parse_key(predicted)
    parsed_truth = _parse_key(truth)
    if parsed_predicted is None or parsed_truth is None:
        return False
    (pitch_p, mode_p), (pitch_t, mode_t) = parsed_predicted, parsed_truth
    if mode_p == mode_t:
        return pitch_p == pitch_t
    major_pitch, minor_pitch = (pitch_p, pitch_t) if mode_p == 'major' else (pitch_t, pitch_p)
    # The relative minor lies three semitones below its major (C major -> A minor).
    return (major_pitch - 3) % 12 == minor_pitch


def _parse_time_signature(signature):
    """Return ``(numerator, denominator)`` for ``'N/D'``, or ``None`` if malformed."""
    try:
        numerator, denominator = map(int, signature.split('/'))
    except (AttributeError, ValueError):
        return None
    return numerator, denominator


def _time_signatures_match(predicted, truth):
    """Return True if the signatures are equal or *predicted* has half the numerator.

    NOTE(review): the original comment gave "4/4 and 2/2" as the intended
    equivalence, but the rule actually implemented accepts e.g. a predicted
    2/4 for a ground-truth 4/4.  The implemented rule is kept; confirm intent.
    """
    if predicted == truth:
        return True
    parsed_predicted = _parse_time_signature(predicted)
    parsed_truth = _parse_time_signature(truth)
    if parsed_predicted is None or parsed_truth is None:
        return False
    (num_p, den_p), (num_t, den_t) = parsed_predicted, parsed_truth
    return den_p == den_t and (num_p == num_t or num_p * 2 == num_t)


def _strip_confidence(chords):
    """Turn ``[['C', 0.46], ['G', 2.87]]`` into ``['C', 'G']``; pass other values through."""
    if isinstance(chords, list) and len(chords) > 0 and isinstance(chords[0], list):
        return [chord[0] for chord in chords]
    return chords


def _instrument_pairs(results):
    """Yield ``(predicted instruments, ground-truth instruments)`` per result.

    The prediction is ``mapped_instruments_summary`` unless the record has a
    non-list ``mapped_instruments`` value, in which case that value is used.
    A record without ``mapped_instruments`` falls back to the summary instead
    of raising ``KeyError``.
    """
    for result in results:
        eval_entry = result['eval_entry']
        original_entry = result['original_entry']
        if 'mapped_instruments_summary' in eval_entry and 'instrument_summary' in original_entry:
            if ('mapped_instruments' not in eval_entry
                    or isinstance(eval_entry['mapped_instruments'], list)):
                predicted = eval_entry['mapped_instruments_summary']
            else:
                predicted = eval_entry['mapped_instruments']
            yield predicted, original_entry['instrument_summary']


def _contains_all(predicted, truth):
    """List prediction: every ground-truth item is present.  Otherwise: equality."""
    if isinstance(predicted, list):
        return set(truth).issubset(set(predicted))
    return predicted == truth


def _contains_any(predicted, truth):
    """List prediction: at least one ground-truth item is present.  Otherwise: equality."""
    if isinstance(predicted, list):
        predicted_set = set(predicted)
        return any(item in predicted_set for item in set(truth))
    return predicted == truth


def _contains_three(predicted, truth):
    """List prediction: all ground-truth items if there are at most three,
    else at least three of them.  Otherwise: equality."""
    if isinstance(predicted, list):
        predicted_set = set(predicted)
        truth_set = set(truth)
        if len(truth_set) <= 3:
            return truth_set.issubset(predicted_set)
        return sum(1 for item in truth_set if item in predicted_set) >= 3
    return predicted == truth


# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------

def calculate_TBT_score(results, lower_tolerance=10, upper_tolerance=15):
    """Tempo Bin with Tolerance (TBT).

    A result is correct when the predicted tempo lies within
    ``[truth - lower_tolerance, truth + upper_tolerance]``.  Results without a
    tempo on either side are not counted.

    Returns:
        The fraction of counted results that are correct (``0`` if none).
    """
    def within_tolerance(predicted, truth):
        return truth - lower_tolerance <= predicted <= truth + upper_tolerance

    return _accuracy("TB", _iter_pairs(results, 'tempo', 'tempo'), within_tolerance)


def calculate_CK_score(results):
    """Correct Key (CK): the predicted key equals the ground-truth key.

    A missing (``None``) key on either side is treated as ``"C major"``.
    """
    return _accuracy("CK", _key_pairs(results), lambda predicted, truth: predicted == truth)


def calculate_CKD_score(results):
    """Correct Key with Duplicates (CKD).

    The predicted key equals the ground-truth key or an equivalent one, i.e. a
    major key and its relative minor (C major / A minor) or an enharmonic
    spelling of the same key.  Keys sharing only the tonic (C major / C minor)
    are *not* equivalent.  A missing key is treated as ``"C major"``.
    """
    return _accuracy("CKD", _key_pairs(results), _keys_equivalent)


def calculate_CTS_score(results):
    """Correct Time Signature (CTS).

    The predicted time signature equals the ground truth, or has the same
    denominator and half its numerator.  Signatures that cannot be parsed as
    ``'N/D'`` count as a mismatch.
    """
    return _accuracy(
        "CTS", _iter_pairs(results, 'time_signature', 'time_signature'), _time_signatures_match
    )


def calculate_ECM_score(results):
    """Exact Chord Match (ECM): the predicted chord summary equals the ground truth.

    The prediction is the first element of the eval record's ``chord_summary``
    when that field is a list.
    """
    return _accuracy(
        "ECM",
        _iter_pairs(results, 'chord_summary', 'chord_summary'),
        lambda predicted, truth: predicted == truth,
    )


def calculate_CMO_score(results):
    """Chord Match in any Order (CMO).

    A result is correct when every chord of the ground-truth ``chord_summary``
    occurs somewhere in the predicted ``chords`` sequence, regardless of order.
    Confidence values attached to predicted chords are ignored.
    """
    return _accuracy(
        "CMO",
        _iter_pairs(results, 'chords', 'chord_summary', extract=_strip_confidence),
        lambda predicted, truth: set(truth).issubset(set(predicted)),
    )


def calculate_CI_score(results):
    """Correct Instrument (CI): all ground-truth instruments are among the predicted ones."""
    return _accuracy("CI", _instrument_pairs(results), _contains_all)


def calculate_CI_top1_score(results):
    """Correct Instrument Top-1: at least one ground-truth instrument is predicted."""
    return _accuracy("CI Top-1", _instrument_pairs(results), _contains_any)


def calculate_CG_score(results):
    """Correct Genre (CG): all ground-truth genres are contained in the prediction.

    The prediction is the first element of the eval record's ``genre`` list.
    """
    return _accuracy("CG", _iter_pairs(results, 'genre', 'genre'), _contains_all)


def calculate_CG_top1_score(results):
    """Correct Genre Top-1: at least one ground-truth genre is contained in the prediction."""
    return _accuracy("CG Top-1", _iter_pairs(results, 'genre', 'genre'), _contains_any)


def calculate_CM_score(results):
    """Correct Mood (CM): all ground-truth moods are contained in the prediction.

    The prediction is the first element of the eval record's ``mood`` list.
    """
    return _accuracy("CM", _iter_pairs(results, 'mood', 'mood'), _contains_all)


def calculate_CM_top1_score(results):
    """Correct Mood Top-1: at least one ground-truth mood is contained in the prediction."""
    return _accuracy("CM Top-1", _iter_pairs(results, 'mood', 'mood'), _contains_any)


def calculate_CM_top3_score(results):
    """Correct Mood Top-3.

    With at most three ground-truth moods, all of them must be contained in
    the prediction; with more than three, at least three must be.
    """
    return _accuracy("CM Top-3", _iter_pairs(results, 'mood', 'mood'), _contains_three)


def calculate_all_scores(results):
    """Calculate every metric for *results* and return them as a dict."""
    scores = {
        'TBT_score': calculate_TBT_score(results),
        'CK_score': calculate_CK_score(results),
        'CKD_score': calculate_CKD_score(results),
        'CTS_score': calculate_CTS_score(results),
        'ECM_score': calculate_ECM_score(results),
        'CMO_score': calculate_CMO_score(results),
        'CI_score': calculate_CI_score(results),
        'CI_top1_score': calculate_CI_top1_score(results),
        'CG_score': calculate_CG_score(results),
        'CG_top1_score': calculate_CG_top1_score(results),
        'CM_score': calculate_CM_score(results),
        'CM_top1_score': calculate_CM_top1_score(results),
        'CM_top3_score': calculate_CM_top3_score(results),
    }
    return scores


def main():
    """Load the inputs, compute all scores, print them and save them as JSON."""
    scores = calculate_all_scores(load_state()['results'])
    print("All Scores:")
    for score_name, score_value in scores.items():
        print(f"{score_name}: {score_value:.4f}")

    # Save the scores next to the generated files.
    output_file = f"{generate_path}/results.json"
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)
    print(f"Results saved to {output_file}")


if __name__ == "__main__":
    main()