add gitignore

author FelixChan
date   2025-09-25 15:17:59 +08:00
parent 83707ea927
commit a85731d254
3 changed files with 398 additions and 177 deletions

@@ -268,7 +268,10 @@ class SymbolicMusicDataset(Dataset):
def _get_split_list_from_tune_in_idx(self, ratio, seed):
# Split the dataset into train, validation, and test sets based on the given ratio
shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names
try:
shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names
except:
shuffled_tune_names = []
random.seed(seed) # Set the seed for reproducibility
random.shuffle(shuffled_tune_names) # Shuffle the tune names
@@ -413,7 +416,7 @@ class LakhClean(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class LakhClean(SymbolicMusicDataset):
class chorus(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
@@ -470,6 +473,124 @@ class LakhClean(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class Melody(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
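# A minimal sketch (hypothetical path and shape, not from this commit) of the .npz
# round-trip the loader above relies on: an unnamed array passed to np.savez_compressed
# is stored under numpy's default key 'arr_0', which is why the code indexes ['arr_0'].
import numpy as np
tune = np.random.randint(0, 100, size=(512, 8))   # (num_tokens, num_features), made-up shape
np.savez_compressed("example_tune.npz", tune)     # unnamed array -> saved as 'arr_0'
assert (np.load("example_tune.npz")["arr_0"] == tune).all()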
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
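# A small illustration (hypothetical song names) of the grouping performed above: a
# trailing ".<digits>" suffix marks alternative versions of one song, and all versions
# are kept in the same split so near-duplicates cannot leak across train/valid/test.
import re
names = ["Yesterday.1", "Yesterday.2", "Let_It_Be"]
groups = {}
for name in names:
    groups.setdefault(re.sub(r"\.\d+$", "", name), []).append(name)
# groups -> {"Yesterday": ["Yesterday.1", "Yesterday.2"], "Let_It_Be": ["Let_It_Be"]}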
class IrishMan(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
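# A hedged sketch (made-up stems) of the irregular-tunes filter used above: the JSON file
# is assumed to hold the file stems to exclude, and membership is tested with
# `stem in irregular_tunes`, so a plain list (or a dict keyed by stem) would both work.
import json
irregular_tunes = json.loads('["tune_0001", "tune_0042"]')
stems = ["tune_0001", "tune_0002"]
kept = [s for s in stems if s not in irregular_tunes]   # -> ["tune_0002"]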
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
try:
shuffled_tune_names = list(self.tune_in_idx.keys())
except:
shuffled_tune_names = []
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
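# Worked example (made-up numbers) of the ratio arithmetic in the split above: the share
# left after training is divided roughly in half between validation and test, with test
# absorbing any rounding remainder.
ratio, n = 0.75, 100
num_train = int(n * ratio)             # 75
num_valid = int(n * (1 - ratio) / 2)   # 12
num_test = n - num_train - num_valid   # 13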
class ariamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
@@ -527,6 +648,123 @@ class ariamidi(SymbolicMusicDataset):
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class ariamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly; mostly they are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
try:
shuffled_tune_names = list(self.tune_in_idx.keys())
except:
shuffled_tune_names = []
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
@@ -1081,11 +1319,7 @@ class LakhALLFined(SymbolicMusicDataset):
'''
# filter out none in tune_in_idx
print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
try:
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
except:
print("Error filtering None values in tune_in_idx, skipping filtering")
return [], [], [], []
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
@@ -1579,11 +1813,7 @@ class FinetuneDataset(SymbolicMusicDataset):
'''
# filter out none in tune_in_idx
print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
try:
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
except:
print("Error filtering None values in tune_in_idx, skipping filtering")
return [], [], [], []
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]

@@ -389,85 +389,6 @@ class XtransformerCrossAttendDecoder(nn.Module):
else:
return self.transformer_decoder(seq, context=context)
class XtransformerLargeCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
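# A minimal sketch (hypothetical caption text) of the caption-encoding step performed in
# forward() above: a frozen T5 encoder maps the tokenized caption to per-token embeddings
# whose last_hidden_state the music decoder attends to as cross-attention context.
import torch
from transformers import AutoTokenizer, T5EncoderModel
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").eval()
batch = tokenizer(["a calm solo piano piece in 3/4"], return_tensors="pt")
with torch.no_grad():
    context = text_encoder(**batch).last_hidden_state   # (1, caption_len, d_model)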
class NewCrossAttendDecoder(nn.Module):
def __init__(
@@ -638,6 +559,75 @@ class NewCrossAttendwithRoPEDecoder(nn.Module):
else:
return self.transformer_decoder(seq, context=context)
class RoPEDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
# self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
# cross_attend = True,
only_cross = False,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
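# A minimal sketch (made-up sizes) of the x-transformers Decoder configured above: it
# consumes already-embedded sequences of shape (batch, seq_len, dim) and returns hidden
# states of the same shape; RoPE and RMSNorm are switched on via constructor flags.
import torch
from x_transformers import Decoder
decoder = Decoder(dim=512, depth=6, heads=8,
                  rotary_pos_emb=True, use_rmsnorm=True,
                  ff_glu=True, ff_swish=True,
                  attn_dropout=0.1, ff_dropout=0.1)
x = torch.randn(2, 128, 512)                              # embeddings, not token ids
hidden, intermediates = decoder(x, return_hiddens=True)   # hidden: (2, 128, 512)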
class XtransformerPrefixDecoder(nn.Module):
def __init__(
self,
@@ -711,7 +701,80 @@ class XtransformerPrefixDecoder(nn.Module):
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerNewPretrainingDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
# shift_tokens = 1,
# attn_qk_norm = True,
# attn_qk_norm_dim_scale = True
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None, context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
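# A hedged sketch (made-up sizes) of the key/value caching that the cache branch above
# relies on: the intermediates returned with return_hiddens=True can be passed back as
# `cache` on the next call, so attention state for already-seen positions is reused
# during step-by-step generation instead of being recomputed.
import torch
from x_transformers import Decoder
decoder = Decoder(dim=512, depth=6, heads=8, rotary_pos_emb=True, use_rmsnorm=True).eval()
prefix = torch.randn(1, 16, 512)                        # already-embedded prompt
hidden, cache = decoder(prefix, return_hiddens=True)    # full pass, cache 16 positions
new_token = torch.randn(1, 1, 512)                      # next embedded token
hidden, cache = decoder(torch.cat([prefix, new_token], dim=1),
                        cache=cache, return_hiddens=True)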
class XtransformerPretrainingDecoder(nn.Module):
def __init__(
self,
@@ -827,92 +890,6 @@ class XtransformerFinetuningDecoder(nn.Module):
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec
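# A minimal sketch (made-up tensors) of the prefix conditioning implemented above: the
# encoded caption is prepended along the time axis, the decoder runs over the joint
# sequence, and the caption positions are sliced off so only music-token states remain.
import torch
context = torch.randn(2, 20, 512)                 # encoded caption: B x context_len x dim
seq = torch.randn(2, 128, 512)                    # embedded music tokens: B x T x dim
joint = torch.cat([context, seq], dim=1)          # B x (context_len + T) x dim
hidden = torch.randn_like(joint)                  # stand-in for transformer_decoder(joint)
hidden = hidden[:, context.shape[1]:, :]          # keep only the music-token positions
assert hidden.shape == seq.shape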
class XtransformerLargeFinetuningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None: