1127 update to latest
This commit is contained in:
@ -3,6 +3,7 @@ import random
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
from typing import Union, List, Tuple, Dict
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
@ -157,6 +158,8 @@ class IterTuneCompiler(IterableDataset):
|
||||
segment = self.augmentor(segment)
|
||||
# use input_ids replace tune_name
|
||||
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
|
||||
# print(segment.shape, mask.shape, tune_name.shape)
|
||||
# segment = segment[torch.randperm(segment.size(0))]
|
||||
yield segment, mask, tune_name, encoded_caption
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user