add gitignore

This commit is contained in:
FelixChan
2025-09-25 15:17:59 +08:00
parent 83707ea927
commit a85731d254
3 changed files with 398 additions and 177 deletions

14
.gitignore vendored Normal file
View File

@ -0,0 +1,14 @@
REMI-tempo-chord-checkpoint/
REMI_decoded/
dataset
vocab
__pycache__/
analysis/
outputs/
pre_trained/
wandb/
*.csv
*.pyc
*.pkl
.vscode/
checkpoints/

View File

@ -268,7 +268,10 @@ class SymbolicMusicDataset(Dataset):
def _get_split_list_from_tune_in_idx(self, ratio, seed): def _get_split_list_from_tune_in_idx(self, ratio, seed):
# Split the dataset into train, validation, and test sets based on the given ratio # Split the dataset into train, validation, and test sets based on the given ratio
shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names try:
shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names
except:
shuffled_tune_names = []
random.seed(seed) # Set the seed for reproducibility random.seed(seed) # Set the seed for reproducibility
random.shuffle(shuffled_tune_names) # Shuffle the tune names random.shuffle(shuffled_tune_names) # Shuffle the tune names
@ -413,7 +416,239 @@ class LakhClean(SymbolicMusicDataset):
test_names.extend(song_dict[song_name]) test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names return train_names, valid_names, test_names, shuffled_tune_names
class LakhClean(SymbolicMusicDataset): class chorus(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly, mostly theay are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class Melody(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly, mostly theay are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class IrishMan(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly, mostly theay are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
try:
shuffled_tune_names = list(self.tune_in_idx.keys())
except:
shuffled_tune_names = []
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class ariamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
for_evaluation=for_evaluation)
def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
'''
Irregular tunes are removed from the dataset for better generation quality
It includes tunes that are not quantized properly, mostly theay are expressive performance data
'''
print("preprocessed tune_in_idx data is being loaded")
tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
if self.debug:
tune_in_idx_list = tune_in_idx_list[:5000]
tune_in_idx_dict = OrderedDict()
len_tunes = OrderedDict()
file_name_list = []
with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
irregular_tunes = json.load(f)
for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
if tune_in_idx_file.stem in irregular_tunes:
continue
tune_in_idx = np.load(tune_in_idx_file)['arr_0']
tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
file_name_list.append(tune_in_idx_file.stem)
print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
return tune_in_idx_dict, len_tunes, file_name_list
def _get_split_list_from_tune_in_idx(self, ratio, seed):
'''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
'''
shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
if song not in song_dict:
song_dict[song] = []
song_dict[song].append(orig_song)
unique_song_names = list(song_dict.keys())
random.seed(seed)
random.shuffle(unique_song_names)
num_train = int(len(unique_song_names)*ratio)
num_valid = int(len(unique_song_names)*(1-ratio)/2)
train_names = []
valid_names = []
test_names = []
for song_name in unique_song_names[:num_train]:
train_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train:num_train+num_valid]:
valid_names.extend(song_dict[song_name])
for song_name in unique_song_names[num_train+num_valid:]:
test_names.extend(song_dict[song_name])
return train_names, valid_names, test_names, shuffled_tune_names
class gigamidi(SymbolicMusicDataset):
def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None, def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
for_evaluation: bool = False): for_evaluation: bool = False):
super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path, super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
@ -504,7 +739,10 @@ class ariamidi(SymbolicMusicDataset):
''' '''
As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name As Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name
''' '''
shuffled_tune_names = list(self.tune_in_idx.keys()) try:
shuffled_tune_names = list(self.tune_in_idx.keys())
except:
shuffled_tune_names = []
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
song_dict = {} song_dict = {}
for song, orig_song in zip(song_names_without_version, shuffled_tune_names): for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
@ -1081,11 +1319,7 @@ class LakhALLFined(SymbolicMusicDataset):
''' '''
# filter out none in tune_in_idx # filter out none in tune_in_idx
print("length of tune_in_idx before filtering:", len(self.tune_in_idx)) print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
try: self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
except:
print("Error filtering None values in tune_in_idx, skipping filtering")
return [], [], [], []
print("length of tune_in_idx after filtering:", len(self.tune_in_idx)) print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
shuffled_tune_names = list(self.tune_in_idx.keys()) shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
@ -1579,11 +1813,7 @@ class FinetuneDataset(SymbolicMusicDataset):
''' '''
# filter out none in tune_in_idx # filter out none in tune_in_idx
print("length of tune_in_idx before filtering:", len(self.tune_in_idx)) print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
try: self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
except:
print("Error filtering None values in tune_in_idx, skipping filtering")
return [], [], [], []
print("length of tune_in_idx after filtering:", len(self.tune_in_idx)) print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
shuffled_tune_names = list(self.tune_in_idx.keys()) shuffled_tune_names = list(self.tune_in_idx.keys())
song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names] song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]

View File

@ -389,85 +389,6 @@ class XtransformerCrossAttendDecoder(nn.Module):
else: else:
return self.transformer_decoder(seq, context=context) return self.transformer_decoder(seq, context=context)
class XtransformerLargeCrossAttendDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
cross_attend = True,
only_cross = False)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask
).last_hidden_state
else:
context = context_embedding
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq, context=context)
class NewCrossAttendDecoder(nn.Module): class NewCrossAttendDecoder(nn.Module):
def __init__( def __init__(
@ -638,6 +559,75 @@ class NewCrossAttendwithRoPEDecoder(nn.Module):
else: else:
return self.transformer_decoder(seq, context=context) return self.transformer_decoder(seq, context=context)
class RoPEDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
# self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
# cross_attend = True,
only_cross = False,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerPrefixDecoder(nn.Module): class XtransformerPrefixDecoder(nn.Module):
def __init__( def __init__(
self, self,
@ -712,6 +702,79 @@ class XtransformerPrefixDecoder(nn.Module):
else: else:
return self.transformer_decoder(seq) return self.transformer_decoder(seq)
class XtransformerNewPretrainingDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True,
use_rmsnorm=True,
rotary_pos_emb = True,
ff_swish = True, # set this to True
ff_glu = True, # set to true to use for all feedforwards
# shift_tokens = 1,
# attn_qk_norm = True,
# attn_qk_norm_dim_scale = True
)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None, context_embedding=None):
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
return hidden_vec, intermediates
else:
return self.transformer_decoder(seq)
class XtransformerPretrainingDecoder(nn.Module): class XtransformerPretrainingDecoder(nn.Module):
def __init__( def __init__(
self, self,
@ -861,89 +924,3 @@ class XtransformerFinetuningDecoder(nn.Module):
hidden_vec = self.transformer_decoder(seq) hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :] hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec return hidden_vec
class XtransformerLargeFinetuningDecoder(nn.Module):
def __init__(
self,
dim:int,
depth:int,
heads:int,
dropout:float
):
super().__init__()
self._make_decoder_layer(dim, depth, heads, dropout)
self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
# frozen text encoder
for param in self.text_encoder.parameters():
param.requires_grad = False
def _make_decoder_layer(self, dim, depth, heads, dropout):
self.transformer_decoder = Decoder(
dim = dim,
depth = depth,
heads = heads,
attn_dropout = dropout,
ff_dropout = dropout,
attn_flash = True)
# add final dropout
print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
self._apply_xavier_init()
print('Adding dropout after feedforward layer in x-transformer')
self._add_dropout_after_ff(dropout)
print('Adding dropout after attention layer in x-transformer')
self._add_dropout_after_attn(dropout)
def _add_dropout_after_attn(self, dropout):
for layer in self.transformer_decoder.layers:
if 'Attention' in str(type(layer[1])):
if isinstance(layer[1].to_out, nn.Sequential): # if GLU
layer[1].to_out.append(nn.Dropout(dropout))
elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
else:
raise ValueError('to_out should be either nn.Sequential or nn.Linear')
def _add_dropout_after_ff(self, dropout):
for layer in self.transformer_decoder.layers:
if 'FeedForward' in str(type(layer[1])):
layer[1].ff.append(nn.Dropout(dropout))
def _apply_xavier_init(self):
for name, param in self.transformer_decoder.named_parameters():
if 'to_q' in name or 'to_k' in name or 'to_v' in name:
torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
if context_embedding is None:
input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
assert input_ids is not None, 'input_ids should be provided for prefix decoder'
assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
context = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
).last_hidden_state
else:
context = context_embedding
# concatenate context with seq
seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
if cache is not None: # implementing run_one_step in inference
if cache.hiddens is None: cache = None
hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
# cut to only return the seq part
return hidden_vec, intermediates
else:
if train:
hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
# cut to only return the seq part
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec, intermediates
else:
# cut to only return the seq part
hidden_vec = self.transformer_decoder(seq)
hidden_vec = hidden_vec[:, context.shape[1]:, :]
return hidden_vec