1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -74,7 +74,7 @@ class LanguageModelTrainer:
sampling_threshold: float, # Threshold for sampling decisions
sampling_temperature: float, # Temperature for controlling sampling randomness
config, # Configuration parameters (contains general, training, and inference settings)
model_checkpoint="wandb/run-20251016_180043-70ihsi93/files/checkpoints/iter80999_loss0.0300.pt", # Path to a pre-trained model checkpoint (optional)
model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt", # Path to a pre-trained model checkpoint (optional)
# model_checkpoint: Union[str, None] = None, # Path to a pre-trainmodl checkpoint (optional)
):
# Save model, optimizer, and other configurations
@ -104,7 +104,6 @@ class LanguageModelTrainer:
checkpoint = torch.load(model_checkpoint, map_location='cpu')
# print state dict keys
print("Loading model checkpoint from", model_checkpoint)
print("Checkpoint keys:", checkpoint['model'].keys())
if isinstance(self.model, DDP):
self.model.module.load_state_dict(checkpoint['model'], strict=False)
else:
@ -902,9 +901,9 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
correct_guess_by_feature = defaultdict(int)
num_tokens_by_feature = defaultdict(int)
for idx, key in enumerate(self.vocab.feature_list):
if key == 'type':
if key == 'type' or key == 'timesig' :
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=None, conti_token=None)
elif key == 'chord' or key == 'tempo' or key == 'instrument':
elif key == 'chord' or key == 'tempo' or key == 'instrument' or key == 'program':
num_correct_tokens, num_valid_tokens = self._get_num_valid_and_correct_tokens(probs_dict[key], target[..., idx], mask, ignore_token=0, conti_token=9999)
elif key == 'beat':
# NB's beat vocab has Ignore and CONTI token