1127 update to latest
@@ -74,8 +74,8 @@ class LanguageModelTrainer:
        sampling_threshold: float,  # Threshold for sampling decisions
        sampling_temperature: float,  # Temperature for controlling sampling randomness
        config,  # Configuration parameters (contains general, training, and inference settings)
        model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt",  # Path to a pre-trained model checkpoint (optional)
        # model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
        # model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt",  # Path to a pre-trained model checkpoint (optional)
        model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
    ):
        # Save model, optimizer, and other configurations
        self.model = model
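The hunk above keeps `model_checkpoint` as an optional path. As a hedged illustration only (not code from this commit), a constructor could consume that argument along the lines of the sketch below; the helper name `_maybe_load_checkpoint` and the checkpoint keys "model" / "optimizer" are assumptions.

# --- illustrative sketch, not part of the commit ---
import torch

def _maybe_load_checkpoint(model, optimizer, model_checkpoint, device):
    """Restore model/optimizer state when a checkpoint path is given; otherwise do nothing."""
    if model_checkpoint is None:
        return model, optimizer
    ckpt = torch.load(model_checkpoint, map_location=device)
    model.load_state_dict(ckpt["model"])          # assumed key layout
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    return model, optimizer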
@@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
        segment, mask, caption, encoded_caption = batch
        input_seq, target = segment[:, :-1], segment[:, 1:]
        total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
        try:
            # Some models return (auxiliary AR logits, per-key logits dict); unpack if so.
            aux_ar_logits, logits_dict = logits_dict
        except (ValueError, TypeError):
            pass  # logits_dict is already the per-key logits dict
        probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
        num_nonmask_tokens = torch.sum(mask)
        input_seq = input_seq.to(self.device)
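For context only: a hedged sketch of how `probs_dict`, the shifted targets, and `mask` might feed a masked per-key accuracy metric during validation. The per-key target dict layout and the function name `masked_accuracy` are assumptions, not code from this commit.

# --- illustrative sketch, not part of the commit ---
import torch

def masked_accuracy(probs_dict, target_dict, mask):
    """Per-key fraction of non-masked positions whose argmax prediction matches the target."""
    accuracy = {}
    num_nonmask_tokens = torch.sum(mask).clamp(min=1)  # avoid division by zero
    for key, probs in probs_dict.items():
        pred = probs.argmax(dim=-1)                              # (batch, seq_len)
        correct = (pred == target_dict[key]).float() * mask.float()
        accuracy[key] = torch.sum(correct) / num_nonmask_tokens
    return accuracy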