46 lines
1.6 KiB
Python
46 lines
1.6 KiB
Python
import sys
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from Amadeus.evaluation_utils import Evaluator, wandb_style_config_to_omega_config, prepare_model_and_dataset_from_config, get_best_ckpt_path_and_config
|
|
|
|
def _micro_average_nll(evaluator, skip_keys=('type',)):
    """Return the token-weighted (micro) mean NLL over the evaluator's feature classes.

    Computes ``sum(all losses) / sum(all counts)`` over every key in
    ``evaluator.vocab.feature_list`` except those in ``skip_keys``. This equals
    the count-weighted average of the per-class means, without dividing and
    then re-multiplying by each class count.

    Args:
        evaluator: an evaluator whose ``loss_by_class[key]`` is an iterable of
            losses and whose ``count_by_class[key]`` is the matching count.
        skip_keys: feature keys excluded from the average.

    Returns:
        The micro-averaged NLL.

    Raises:
        ValueError: if no included feature class has a non-zero count.
    """
    total_loss = 0
    total_count = 0
    for key in evaluator.vocab.feature_list:
        # 'type' is skipped by default because type/metric tokens have
        # different meanings across encoding schemes.
        if key in skip_keys:
            continue
        count = evaluator.count_by_class[key]
        if count == 0:
            # An empty class adds nothing to either sum; skipping it avoids a
            # division by zero when a class never appears in the test set.
            continue
        total_loss += sum(evaluator.loss_by_class[key])
        total_count += count
    if total_count == 0:
        raise ValueError('no tokens were counted in any evaluated feature class')
    return total_loss / total_count


def main(exp_code, wandb_dir='wandb', device='cuda', batch_size=21):
    """Evaluate a run's checkpoint and return its micro-averaged NLL.

    Resolves the checkpoint, config, metadata and vocab paths for ``exp_code``
    via ``get_best_ckpt_path_and_config``, then rebuilds the model and test set
    from the config and loads the checkpoint weights. It runs the
    ``Evaluator`` perplexity pass and saves the evaluator results under
    ``<wandb_dir>/<exp_code>/``.

    Args:
        exp_code: experiment/run identifier; also used as the sub-directory
            of ``wandb_dir`` where results are saved.
        wandb_dir: root directory of the wandb runs. Defaults to ``'wandb'``.
        device: device string passed to ``Evaluator``. Defaults to ``'cuda'``.
        batch_size: batch size passed to ``Evaluator``. Defaults to 21.

    Returns:
        Total summed loss divided by total token count over all feature
        classes except ``'type'``.

    Raises:
        ValueError: if no included feature class has a non-zero token count.
    """
    wandb_dir = Path(wandb_dir)
    ckpt_path, config_path, metadata_path, vocab_path = get_best_ckpt_path_and_config(wandb_dir, exp_code)
    config = wandb_style_config_to_omega_config(OmegaConf.load(config_path))
    # Show which checkpoint was selected so the run can be reproduced.
    print(ckpt_path)

    # map_location='cpu' makes loading work even if the checkpoint was saved
    # from a GPU that is not present here.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model, test_set, vocab = prepare_model_and_dataset_from_config(
        config, metadata_path=metadata_path, vocab_path=vocab_path
    )
    model.load_state_dict(ckpt['model'])
    model = model.eval()

    evaluator = Evaluator(config, model, test_set, vocab, device=device, batch_size=batch_size)
    evaluator.get_perplexity()
    evaluator.save_results(wandb_dir / exp_code / 'micro_evaluated_perplexity_conti_fixed.pt')

    return _micro_average_nll(evaluator)
|
|
|
|
if __name__ == '__main__':
    # Fail with a usage message instead of a bare IndexError when the
    # experiment code is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} EXP_CODE')
    exp_code = sys.argv[1]
    print(main(exp_code))