1029 add octuple

This commit is contained in:
Mars
2025-10-29 17:14:33 +08:00
parent b493ede479
commit e16c84aab2
22 changed files with 1135 additions and 62 deletions

View File

@ -95,11 +95,6 @@ class TuneCompiler(Dataset):
print(f"Error encoding caption for tune {tune_name}: {e}")
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
return segment, tensor_mask, tune_name, encoded_caption
if self.data_type == 'train':
augmented_segment = self.augmentor(segment)
return augmented_segment, tensor_mask, tune_name, encoded_caption
else:
return segment, tensor_mask, tune_name, encoded_caption
def get_segments_with_tune_idx(self, tune_name, seg_order):
'''
@ -135,6 +130,7 @@ class IterTuneCompiler(IterableDataset):
self.data_type = data_type
self.augmentor = augmentor
self.eos_token = vocab.eos_token
self.vocab = vocab
self.compile_function = VanillaTransformer_compiler(
data_list=self.data_list,
augmentor=self.augmentor,
@ -157,7 +153,7 @@ class IterTuneCompiler(IterableDataset):
encoded_caption = self.t5_tokenizer(tune_name, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
except Exception as e:
encoded_caption = self.t5_tokenizer("No caption available", return_tensors='pt', padding='max_length', truncation=True, max_length=128)
if self.data_type == 'train':
if self.data_type == 'train' and self.vocab.encoding_scheme != 'oct':
segment = self.augmentor(segment)
# Replace tune_name with the input_ids of the encoded caption
tune_name = encoded_caption['input_ids'][0] # Use the input_ids from the encoded caption