first commit
This commit is contained in:
208
Amadeus/symbolic_encoding/metric_utils.py
Normal file
208
Amadeus/symbolic_encoding/metric_utils.py
Normal file
@ -0,0 +1,208 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
|
||||
# TODO: refactor hard coded values
|
||||
def check_syntax_errors_in_inference_for_nb(generated_output, feature_list):
|
||||
generated_output = generated_output.squeeze(0)
|
||||
type_idx = feature_list.index('type')
|
||||
beat_idx = feature_list.index('beat')
|
||||
type_beat_list = []
|
||||
for token in generated_output:
|
||||
type_beat_list.append((token[type_idx].item(), token[beat_idx].item())) # type, beat
|
||||
|
||||
last_note = 1
|
||||
beat_type_unmatched_error_list = []
|
||||
num_unmatched_errors = 0
|
||||
beat_backwards_error_list = []
|
||||
num_backwards_errors = 0
|
||||
for type_beat in type_beat_list:
|
||||
if type_beat[0] == 4: # same bar, new beat
|
||||
if type_beat[1] == 0 or type_beat[1] == 1:
|
||||
num_unmatched_errors += 1
|
||||
beat_type_unmatched_error_list.append(type_beat)
|
||||
if type_beat[1] <= last_note:
|
||||
num_backwards_errors += 1
|
||||
beat_backwards_error_list.append([last_note, type_beat])
|
||||
else:
|
||||
last_note = type_beat[1] # update last note
|
||||
elif type_beat[0] >= 5: # new bar, new beat
|
||||
if type_beat[1] == 0:
|
||||
num_unmatched_errors += 1
|
||||
beat_type_unmatched_error_list.append(type_beat)
|
||||
last_note = 1
|
||||
unmatched_error_rate = num_unmatched_errors / len(type_beat_list)
|
||||
backwards_error_rate = num_backwards_errors / len(type_beat_list)
|
||||
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
|
||||
return type_beat_errors_dict
|
||||
|
||||
def check_syntax_errors_in_inference_for_cp(generated_output, feature_list):
|
||||
generated_output = generated_output.squeeze(0)
|
||||
type_idx = feature_list.index('type')
|
||||
beat_idx = feature_list.index('beat')
|
||||
pitch_idx = feature_list.index('pitch')
|
||||
duration_idx = feature_list.index('duration')
|
||||
last_note = 1
|
||||
beat_type_unmatched_error_list = []
|
||||
num_unmatched_errors = 0
|
||||
beat_backwards_error_list = []
|
||||
num_backwards_errors = 0
|
||||
for token in generated_output:
|
||||
if token[type_idx].item() == 2: # Metrical
|
||||
if token[pitch_idx].item() != 0 or token[duration_idx].item() != 0:
|
||||
num_unmatched_errors += 1
|
||||
beat_type_unmatched_error_list.append(token)
|
||||
if token[beat_idx].item() == 1: # new bar
|
||||
last_note = 1 # last note will be updated in the next token
|
||||
elif token[beat_idx].item() != 0 and token[beat_idx].item() <= last_note:
|
||||
num_backwards_errors += 1
|
||||
last_note = token[beat_idx].item() # update last note
|
||||
beat_backwards_error_list.append([last_note, token])
|
||||
else:
|
||||
last_note = token[beat_idx].item() # update last note
|
||||
if token[type_idx].item() == 3: # Note
|
||||
if token[beat_idx].item() != 0:
|
||||
num_unmatched_errors += 1
|
||||
beat_type_unmatched_error_list.append(token)
|
||||
unmatched_error_rate = num_unmatched_errors / len(generated_output)
|
||||
backwards_error_rate = num_backwards_errors / len(generated_output)
|
||||
type_beat_errors_dict = {'beat_type_unmatched_error': unmatched_error_rate, 'beat_backwards_error': backwards_error_rate}
|
||||
return type_beat_errors_dict
|
||||
|
||||
def check_syntax_errors_in_inference_for_remi(generated_output, vocab):
|
||||
generated_output = generated_output.squeeze(0)
|
||||
# to check duration errors
|
||||
beat_mask = vocab.total_mask['beat'].to(generated_output.device)
|
||||
beat_mask_for_target = beat_mask[generated_output]
|
||||
beat_target = generated_output * beat_mask_for_target
|
||||
bar_mask = vocab.total_mask['type'].to(generated_output.device)
|
||||
bar_mask_for_target = bar_mask[generated_output]
|
||||
bar_target = (generated_output+1) * bar_mask_for_target # as bar token in 0 in remi vocab, we add 1 to bar token
|
||||
target = beat_target + bar_target
|
||||
target = target[target!=0]
|
||||
# collect beats in between bars(idx=1)
|
||||
num_backwards_errors = 0
|
||||
collected_beats = []
|
||||
total_beats = 0
|
||||
for token in target:
|
||||
if token == 1 or 3 <= token <= 26: # Bar_None, or Bar_time_signature
|
||||
collected_beats_tensor = torch.tensor(collected_beats)
|
||||
diff = torch.diff(collected_beats_tensor)
|
||||
num_error_beats = torch.where(diff<=0)[0].shape[0]
|
||||
num_backwards_errors += num_error_beats
|
||||
collected_beats = []
|
||||
else:
|
||||
collected_beats.append(token.item())
|
||||
total_beats += 1
|
||||
if total_beats != 0:
|
||||
backwards_error_rate = num_backwards_errors / total_beats
|
||||
else:
|
||||
backwards_error_rate = 0
|
||||
# print(f"error rate in beat backwards: {backwards_error_rate}")
|
||||
return {'beat_backwards_error': backwards_error_rate}
|
||||
|
||||
def type_beat_errors_in_validation_nb(beat_prob, answer_type, input_beat, mask):
|
||||
bool_mask = mask.bool().flatten() # (b*t)
|
||||
pred_beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
|
||||
valid_pred_beat_idx = pred_beat_idx[bool_mask] # valid beat_idx
|
||||
answer_type = answer_type.flatten() # (b*t)
|
||||
valid_type_input = answer_type[bool_mask] # valid answer_type
|
||||
type_beat_list = []
|
||||
for i in range(len(valid_pred_beat_idx)):
|
||||
type_beat_list.append((valid_type_input[i].item(), valid_pred_beat_idx[i].item())) # type, beat
|
||||
input_beat = input_beat.flatten()
|
||||
valid_input_beat = input_beat[bool_mask]
|
||||
|
||||
last_note = 1
|
||||
num_unmatched_errors = 0
|
||||
num_backwards_errors = 0
|
||||
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
|
||||
# update last note
|
||||
if input_beat_idx.item() >= 1: # beat
|
||||
last_note = input_beat_idx.item()
|
||||
if type_beat[0] == 4: # same bar, new beat
|
||||
if type_beat[1] == 0 or type_beat[1] == 1:
|
||||
num_unmatched_errors += 1
|
||||
if type_beat[1] <= last_note:
|
||||
num_backwards_errors += 1
|
||||
elif type_beat[0] >= 5: # new bar, new beat
|
||||
if type_beat[1] == 0:
|
||||
num_unmatched_errors += 1
|
||||
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
|
||||
|
||||
def type_beat_errors_in_validation_cp(beat_prob, answer_type, input_beat, mask):
|
||||
bool_mask = mask.bool().flatten() # (b*t)
|
||||
beat_idx = torch.argmax(beat_prob, dim=-1).flatten() # (b*t)
|
||||
valid_beat_idx = beat_idx[bool_mask] # valid beat_idx
|
||||
answer_type = answer_type.flatten() # (b*t)
|
||||
valid_type_input = answer_type[bool_mask] # valid answer_type
|
||||
type_beat_list = []
|
||||
for i in range(len(valid_beat_idx)):
|
||||
type_beat_list.append((valid_type_input[i].item(), valid_beat_idx[i].item())) # type, beat
|
||||
input_beat = input_beat.flatten()
|
||||
valid_input_beat = input_beat[bool_mask]
|
||||
|
||||
last_note = 1
|
||||
num_unmatched_errors = 0
|
||||
num_backwards_errors = 0
|
||||
for type_beat, input_beat_idx in zip(type_beat_list, valid_input_beat):
|
||||
# update last note
|
||||
if input_beat_idx.item() == 1: # bar
|
||||
last_note = 1
|
||||
elif input_beat_idx.item() >= 2: # new beat
|
||||
last_note = input_beat_idx.item()
|
||||
# check errors
|
||||
if type_beat[0] == 2: # Metrical
|
||||
if type_beat[1] == 0: # ignore
|
||||
num_unmatched_errors += 1
|
||||
elif type_beat[1] >= 2: # new beat
|
||||
if type_beat[1] <= last_note:
|
||||
num_backwards_errors += 1
|
||||
elif type_beat[0] == 3: # Note
|
||||
if type_beat[1] != 0:
|
||||
num_unmatched_errors += 1
|
||||
return len(type_beat_list), num_unmatched_errors, num_backwards_errors
|
||||
|
||||
def get_beat_difference_metric(prob_dict, arranged_prob_dict, mask):
|
||||
orign_beat_prob = prob_dict['beat'] # b x t x vocab_size
|
||||
arranged_beat_prob = arranged_prob_dict['beat'] # b x t x vocab_size
|
||||
|
||||
# calculate similarity between original beat prob and arranged beat prob
|
||||
origin_beat_token = torch.argmax(orign_beat_prob, dim=-1) * mask # b x t
|
||||
arranged_beat_token = torch.argmax(arranged_beat_prob, dim=-1) * mask # b x t
|
||||
num_same_beat = torch.sum(origin_beat_token == arranged_beat_token) - torch.sum(mask==0)
|
||||
num_beat = torch.sum(mask==1)
|
||||
beat_sim = (num_same_beat / num_beat).item() # scalar
|
||||
|
||||
# apply mask, shape of mask: b x t
|
||||
orign_beat_prob = orign_beat_prob * mask.unsqueeze(-1) # b x t x vocab_size
|
||||
arranged_beat_prob = arranged_beat_prob * mask.unsqueeze(-1)
|
||||
|
||||
# calculate cosine similarity between original beat prob and arranged beat prob
|
||||
orign_beat_prob = orign_beat_prob.flatten(0,1) # (b*t) x vocab_size
|
||||
arranged_beat_prob = arranged_beat_prob.flatten(0,1) # (b*t) x vocab_size
|
||||
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
|
||||
beat_cos_sim = cos(orign_beat_prob, arranged_beat_prob) # (b*t)
|
||||
# exclude invalid tokens, zero padding tokens
|
||||
beat_cos_sim = beat_cos_sim[mask.flatten().bool()] # num_valid_tokens
|
||||
beat_cos_sim = torch.mean(beat_cos_sim).item() # scalar
|
||||
return {'beat_cos_sim': beat_cos_sim, 'beat_sim': beat_sim}
|
||||
|
||||
def get_gini_coefficient(generated_output):
|
||||
if len(generated_output.shape) == 3:
|
||||
generated_output = generated_output.squeeze(0).tolist()
|
||||
gen_list = [tuple(x) for x in generated_output]
|
||||
else:
|
||||
gen_list = generated_output.squeeze(0).tolist()
|
||||
counts = Counter(gen_list).values()
|
||||
sorted_counts = sorted(counts)
|
||||
n = len(sorted_counts)
|
||||
cumulative_counts = np.cumsum(sorted_counts)
|
||||
cumulative_proportion = cumulative_counts / cumulative_counts[-1]
|
||||
|
||||
lorenz_area = sum(cumulative_proportion[:-1]) / n # Exclude the last element
|
||||
equality_area = 0.5 # The area under line of perfect equality
|
||||
|
||||
gini = (equality_area - lorenz_area) / equality_area
|
||||
return gini
|
||||
Reference in New Issue
Block a user