1029 add octuple
This commit is contained in:
@ -73,7 +73,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
@ -148,7 +148,7 @@ class VanillaTransformer_compiler():
|
||||
for i in range(len(self.data_list)):
|
||||
tune_in_idx, tune_name = self.data_list[i]
|
||||
tune_in_idx = torch.LongTensor(tune_in_idx)
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp':
|
||||
if self.encoding_scheme == 'remi' or self.encoding_scheme == 'cp' or self.encoding_scheme == 'oct':
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
else:
|
||||
eos_token = torch.LongTensor(self.eos_token)
|
||||
|
||||
Reference in New Issue
Block a user