1029 add octuple
This commit is contained in:
@ -74,7 +74,7 @@ class LanguageModelTrainer:
|
||||
sampling_threshold: float, # Threshold for sampling decisions
|
||||
sampling_temperature: float, # Temperature for controlling sampling randomness
|
||||
config, # Configuration parameters (contains general, training, and inference settings)
|
||||
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
|
||||
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
|
||||
):
|
||||
# Save model, optimizer, and other configurations
|
||||
@ -104,7 +104,6 @@ class LanguageModelTrainer:
|
||||
checkpoint = torch.load(model_checkpoint, map_location='cpu')
|
||||
# print state dict keys
|
||||
print("Loading model checkpoint from", model_checkpoint)
|
||||
print("Checkpoint keys:", checkpoint['model'].keys())
|
||||
if isinstance(self.model, DDP):
|
||||
self.model.module.load_state_dict(checkpoint['model'], strict=False)
|
||||
else:
|
||||
@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
|
||||
correct_guess_by_feature = defaultdict(int)
|
||||
num_tokens_by_feature = defaultdict(int)
|
||||
for idx, key in enumerate(self.vocab.feature_list):
|
||||
if key == 'type':
|
||||
if key == 'type' or key == 'timesig' :
|
||||
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument':
|
||||
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
|
||||
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
|
||||
elif key == 'beat':
|
||||
# NB's beat vocab has Ignore and CONTI token
|
||||
|
||||
Reference in New Issue
Block a user