1127 update to latest

This commit is contained in:
FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

View File

@ -99,9 +99,12 @@ def compare_pair(file_a: str, file_b: str):
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
files_a = files_a[:100] # 仅比较前100个文件以节省时间
# Exclude files whose names end with "_prompt.mid"
files_a = [f for f in files_a if not f.endswith("_prompt.mid")]
files_a = files_a
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
results = []
pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
with ProcessPoolExecutor(max_workers=max_workers) as executor:
@ -110,6 +113,8 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
pbar.update(1)
try:
results.append(fut.result())
if results[-1][2] == 0:
print(f"Exact match found: {results[-1][0]} and {results[-1][1]}")
except Exception as e:
print(fut.result())
print(f"Error comparing pair: {e}")
@ -129,6 +134,6 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
if __name__ == "__main__":
dir_a = "wandb/run-20251027_161354-f9j1mwp2/uncond_min_p_t0.05_temp1.25_epochch8"
dir_a = "wandb/run-20251124_104410-bjdyzt85ar_aux_melody/uncond_min_p_t0.2_temp1.25"
dir_b = "dataset/Melody"
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
batch_compare(dir_a, dir_b, out_csv="midi_similarity_withbase_p0.6.csv", max_workers=24)