1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions


@@ -74,8 +74,8 @@ class LanguageModelTrainer:
         sampling_threshold: float,  # Threshold for sampling decisions
         sampling_temperature: float,  # Temperature for controlling sampling randomness
         config,  # Configuration parameters (contains general, training, and inference settings)
-        model_checkpoint="wandb/run-20251025_104202-kd5cf5b3/files/checkpoints/iter42612_loss-8.9870.pt",  # Path to a pre-trained model checkpoint (optional)
-        # model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
+        # model_checkpoint="wandb/run-20251114_151512-k21rnynj/files/checkpoints/iter104999_loss0.2490.pt",  # Path to a pre-trained model checkpoint (optional)
+        model_checkpoint: Union[str, None] = None,  # Path to a pre-trained model checkpoint (optional)
     ):
         # Save model, optimizer, and other configurations
         self.model = model
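
For context on this change: with the default flipped to `None`, weight restoration happens only when a checkpoint path is passed explicitly. Below is a minimal sketch of how such an optional argument is typically consumed; the helper name `load_checkpoint_if_given` and the wrapped checkpoint layout are assumptions, since the loading code itself is not part of this hunk.

from typing import Union
import torch

def load_checkpoint_if_given(model: torch.nn.Module,
                             model_checkpoint: Union[str, None]) -> torch.nn.Module:
    # Hypothetical helper: restore weights when a path is given, else leave the model as-is.
    if model_checkpoint is not None:
        state = torch.load(model_checkpoint, map_location="cpu")
        # Training checkpoints often wrap the weights, e.g. {"model": ..., "optimizer": ...};
        # otherwise treat the file as a raw state_dict (assumed layout).
        state_dict = state.get("model", state) if isinstance(state, dict) else state
        model.load_state_dict(state_dict)
    return model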
@@ -892,6 +892,10 @@ class LanguageModelTrainer4CompoundToken(LanguageModelTrainer):
         segment, mask, caption, encoded_caption = batch
         input_seq, target = segment[:, :-1], segment[:, 1:]
         total_loss, logits_dict, loss_dict = self._get_loss_pred_from_single_batch(batch, valid=True)
+        try:  # unpack when the forward pass returns (aux_ar_logits, logits_dict)
+            aux_ar_logits, logits_dict = logits_dict
+        except (TypeError, ValueError):
+            pass  # logits_dict is already a plain dict of logits
         probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
         num_nonmask_tokens = torch.sum(mask)
         input_seq = input_seq.to(self.device)
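
The added try/except exists because `_get_loss_pred_from_single_batch` can apparently return either a plain per-attribute logits dict or a `(aux_ar_logits, logits_dict)` tuple. A self-contained sketch of the same normalization follows; the attribute names and tensor shapes are invented for illustration.

import torch

# Invented shapes: one logits tensor per compound-token attribute,
# optionally preceded by auxiliary autoregressive logits.
logits_dict = {
    "pitch": torch.randn(2, 7, 90),     # (batch, seq_len, pitch_vocab)
    "duration": torch.randn(2, 7, 64),  # (batch, seq_len, duration_vocab)
}
output = (torch.randn(2, 7, 128), logits_dict)  # the tuple-shaped return

try:  # same normalization as the hunk above
    aux_ar_logits, logits_dict = output
except (TypeError, ValueError):
    logits_dict = output  # already a plain dict

probs_dict = {key: torch.softmax(value, dim=-1) for key, value in logits_dict.items()}
# Each per-attribute softmax sums to one over its own vocabulary.
assert all(torch.allclose(p.sum(dim=-1), torch.ones(2, 7)) for p in probs_dict.values())

Note that an `isinstance(output, tuple)` check would be more robust than unpacking inside try/except: unpacking a dict that happens to have exactly two keys would silently succeed and bind the keys, not the logits.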