1127 update to latest
This commit is contained in:
11
midi_sim.py
11
midi_sim.py
@ -99,9 +99,12 @@ def compare_pair(file_a: str, file_b: str):
|
||||
|
||||
def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv", max_workers: int = 8):
|
||||
files_a = [os.path.join(dir_a, f) for f in os.listdir(dir_a) if f.endswith(".mid")]
|
||||
files_a = files_a[:100] # 仅比较前100个文件以节省时间
|
||||
# remove files end with _prompt.mid
|
||||
files_a = [f for f in files_a if not f.endswith("_prompt.mid")]
|
||||
files_a = files_a
|
||||
files_b = [os.path.join(dir_b, f) for f in os.listdir(dir_b) if f.endswith(".mid")]
|
||||
|
||||
|
||||
results = []
|
||||
pbar = tqdm(total=len(files_a) * len(files_b), desc="Comparing MIDI files")
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
@ -110,6 +113,8 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
|
||||
pbar.update(1)
|
||||
try:
|
||||
results.append(fut.result())
|
||||
if results[-1][2] == 0:
|
||||
print(f"Exact match found: {results[-1][0]} and {results[-1][1]}")
|
||||
except Exception as e:
|
||||
print(fut.result())
|
||||
print(f"Error comparing pair: {e}")
|
||||
@ -129,6 +134,6 @@ def batch_compare(dir_a: str, dir_b: str, out_csv: str = "midi_similarity.csv",
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dir_a = "wandb/run-20251027_161354-f9j1mwp2/uncond_min_p_t0.05_temp1.25_epochch8"
|
||||
dir_a = "wandb/run-20251124_104410-bjdyzt85ar_aux_melody/uncond_min_p_t0.2_temp1.25"
|
||||
dir_b = "dataset/Melody"
|
||||
batch_compare(dir_a, dir_b, out_csv="midi_similarity_v2.csv", max_workers=6)
|
||||
batch_compare(dir_a, dir_b, out_csv="midi_similarity_withbase_p0.6.csv", max_workers=24)
|
||||
Reference in New Issue
Block a user