1029 add octuple
This commit is contained in:
@@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
|
||||
print(f"Error encoding caption for tune {tune_name}: {e}")
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
if self.data_type == 'train':
|
||||
augmented_segment = self.augmentor(segment)
|
||||
return augmented_segment, tensor_mask, tune_name, encoded_caption
|
||||
else:
|
||||
return segment, tensor_mask, tune_name, encoded_caption
|
||||
|
||||
def get_segments_with_tune_idx(self, tune_name, seg_order):
|
||||
'''
|
||||
@@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
self.data_type = data_type
|
||||
self.augmentor = augmentor
|
||||
self.eos_token = vocab.eos_token
|
||||
self.vocab = vocab
|
||||
self.compile_function = VanillaTransformer_compiler(
|
||||
data_list=self.data_list,
|
||||
augmentor=self.augmentor,
|
||||
@@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
|
||||
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
except Exception as e:
|
||||
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
|
||||
if self.data_type == 'train':
|
||||
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
|
||||
segment = self.augmentor(segment)
|
||||
# use input_ids replace tune_name
|
||||
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption
|
||||
|
||||
Reference in New Issue
Block a user