From a85731d254e327efbb630d3a81fd372ecedf9771 Mon Sep 17 00:00:00 2001
From: FelixChan <2223485532@qq.com>
Date: Thu, 25 Sep 2025 15:17:59 +0800
Subject: [PATCH] add gitignore

---
 .gitignore                              |  14 ++
 Amadeus/symbolic_encoding/data_utils.py | 254 +++++++++++++++++++-
 Amadeus/transformer_utils.py            | 307 +++++++++++-------------
 3 files changed, 398 insertions(+), 177 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..eeb1def
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,14 @@
+REMI-tempo-chord-checkpoint/
+REMI_decoded/
+dataset
+vocab
+__pycache__/
+analysis/
+outputs/
+pre_trained/
+wandb/
+*.csv
+*.pyc
+*.pkl
+.vscode/
+checkpoints/
\ No newline at end of file
diff --git a/Amadeus/symbolic_encoding/data_utils.py b/Amadeus/symbolic_encoding/data_utils.py
index 2400d36..0d23e84 100644
--- a/Amadeus/symbolic_encoding/data_utils.py
+++ b/Amadeus/symbolic_encoding/data_utils.py
@@ -268,7 +268,10 @@ class SymbolicMusicDataset(Dataset):
 
     def _get_split_list_from_tune_in_idx(self, ratio, seed):
         # Split the dataset into train, validation, and test sets based on the given ratio
-        shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names
+        try:
+            shuffled_tune_names = list(self.tune_in_idx.keys()) # Get the list of all tune names
+        except:
+            shuffled_tune_names = []
         random.seed(seed) # Set the seed for reproducibility
         random.shuffle(shuffled_tune_names) # Shuffle the tune names
 
@@ -413,7 +416,7 @@ class LakhClean(SymbolicMusicDataset):
                 test_names.extend(song_dict[song_name])
         return train_names, valid_names, test_names, shuffled_tune_names
 
-class LakhClean(SymbolicMusicDataset):
+class chorus(SymbolicMusicDataset):
     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                  for_evaluation: bool = False):
         super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                          for_evaluation=for_evaluation)
@@ -470,6 +473,124 @@ class LakhClean(SymbolicMusicDataset):
                 test_names.extend(song_dict[song_name])
         return train_names, valid_names, test_names, shuffled_tune_names
 
+class Melody(SymbolicMusicDataset):
+    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+                 for_evaluation: bool = False):
+        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+                         for_evaluation=for_evaluation)
+
+    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+        '''
+        Irregular tunes are removed from the dataset for better generation quality.
+        It includes tunes that are not quantized properly; mostly they are expressive performance data.
+        '''
+        print("preprocessed tune_in_idx data is being loaded")
+        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+        if self.debug:
+            tune_in_idx_list = tune_in_idx_list[:5000]
+        tune_in_idx_dict = OrderedDict()
+        len_tunes = OrderedDict()
+        file_name_list = []
+        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+            irregular_tunes = json.load(f)
+        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+            if tune_in_idx_file.stem in irregular_tunes:
+                continue
+            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+            file_name_list.append(tune_in_idx_file.stem)
+        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+        return tune_in_idx_dict, len_tunes, file_name_list
+
+    def _get_split_list_from_tune_in_idx(self, ratio, seed):
+        '''
+        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+        '''
+        shuffled_tune_names = list(self.tune_in_idx.keys())
+        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+        song_dict = {}
+        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+            if song not in song_dict:
+                song_dict[song] = []
+            song_dict[song].append(orig_song)
+        unique_song_names = list(song_dict.keys())
+        random.seed(seed)
+        random.shuffle(unique_song_names)
+        num_train = int(len(unique_song_names)*ratio)
+        num_valid = int(len(unique_song_names)*(1-ratio)/2)
+        train_names = []
+        valid_names = []
+        test_names = []
+        for song_name in unique_song_names[:num_train]:
+            train_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train:num_train+num_valid]:
+            valid_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train+num_valid:]:
+            test_names.extend(song_dict[song_name])
+        return train_names, valid_names, test_names, shuffled_tune_names
+
+
+class IrishMan(SymbolicMusicDataset):
+    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+                 for_evaluation: bool = False):
+        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+                         for_evaluation=for_evaluation)
+
+    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+        '''
+        Irregular tunes are removed from the dataset for better generation quality.
+        It includes tunes that are not quantized properly; mostly they are expressive performance data.
+        '''
+        print("preprocessed tune_in_idx data is being loaded")
+        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+        if self.debug:
+            tune_in_idx_list = tune_in_idx_list[:5000]
+        tune_in_idx_dict = OrderedDict()
+        len_tunes = OrderedDict()
+        file_name_list = []
+        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+            irregular_tunes = json.load(f)
+        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+            if tune_in_idx_file.stem in irregular_tunes:
+                continue
+            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+            file_name_list.append(tune_in_idx_file.stem)
+        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+        return tune_in_idx_dict, len_tunes, file_name_list
+
+    def _get_split_list_from_tune_in_idx(self, ratio, seed):
+        '''
+        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+        '''
+        try:
+            shuffled_tune_names = list(self.tune_in_idx.keys())
+        except:
+            shuffled_tune_names = []
+        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+        song_dict = {}
+        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+            if song not in song_dict:
+                song_dict[song] = []
+            song_dict[song].append(orig_song)
+        unique_song_names = list(song_dict.keys())
+        random.seed(seed)
+        random.shuffle(unique_song_names)
+        num_train = int(len(unique_song_names)*ratio)
+        num_valid = int(len(unique_song_names)*(1-ratio)/2)
+        train_names = []
+        valid_names = []
+        test_names = []
+        for song_name in unique_song_names[:num_train]:
+            train_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train:num_train+num_valid]:
+            valid_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train+num_valid:]:
+            test_names.extend(song_dict[song_name])
+        return train_names, valid_names, test_names, shuffled_tune_names
+
 class ariamidi(SymbolicMusicDataset):
     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                  for_evaluation: bool = False):
         super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                          for_evaluation=for_evaluation)
@@ -527,6 +648,123 @@ class ariamidi(SymbolicMusicDataset):
                 test_names.extend(song_dict[song_name])
         return train_names, valid_names, test_names, shuffled_tune_names
 
+class gigamidi(SymbolicMusicDataset):
+    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+                 for_evaluation: bool = False):
+        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+                         for_evaluation=for_evaluation)
+
+    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+        '''
+        Irregular tunes are removed from the dataset for better generation quality.
+        It includes tunes that are not quantized properly; mostly they are expressive performance data.
+        '''
+        print("preprocessed tune_in_idx data is being loaded")
+        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+        if self.debug:
+            tune_in_idx_list = tune_in_idx_list[:5000]
+        tune_in_idx_dict = OrderedDict()
+        len_tunes = OrderedDict()
+        file_name_list = []
+        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+            irregular_tunes = json.load(f)
+        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+            if tune_in_idx_file.stem in irregular_tunes:
+                continue
+            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+            file_name_list.append(tune_in_idx_file.stem)
+        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+        return tune_in_idx_dict, len_tunes, file_name_list
+
+    def _get_split_list_from_tune_in_idx(self, ratio, seed):
+        '''
+        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+        '''
+        shuffled_tune_names = list(self.tune_in_idx.keys())
+        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+        song_dict = {}
+        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+            if song not in song_dict:
+                song_dict[song] = []
+            song_dict[song].append(orig_song)
+        unique_song_names = list(song_dict.keys())
+        random.seed(seed)
+        random.shuffle(unique_song_names)
+        num_train = int(len(unique_song_names)*ratio)
+        num_valid = int(len(unique_song_names)*(1-ratio)/2)
+        train_names = []
+        valid_names = []
+        test_names = []
+        for song_name in unique_song_names[:num_train]:
+            train_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train:num_train+num_valid]:
+            valid_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train+num_valid:]:
+            test_names.extend(song_dict[song_name])
+        return train_names, valid_names, test_names, shuffled_tune_names
+
+class ariamidi(SymbolicMusicDataset):
+    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
+                 for_evaluation: bool = False):
+        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
+                         for_evaluation=for_evaluation)
+
+    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
+        '''
+        Irregular tunes are removed from the dataset for better generation quality.
+        It includes tunes that are not quantized properly; mostly they are expressive performance data.
+        '''
+        print("preprocessed tune_in_idx data is being loaded")
+        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
+        if self.debug:
+            tune_in_idx_list = tune_in_idx_list[:5000]
+        tune_in_idx_dict = OrderedDict()
+        len_tunes = OrderedDict()
+        file_name_list = []
+        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
+            irregular_tunes = json.load(f)
+        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
+            if tune_in_idx_file.stem in irregular_tunes:
+                continue
+            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
+            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
+            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
+            file_name_list.append(tune_in_idx_file.stem)
+        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
+        return tune_in_idx_dict, len_tunes, file_name_list
+
+    def _get_split_list_from_tune_in_idx(self, ratio, seed):
+        '''
+        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
+        '''
+        try:
+            shuffled_tune_names = list(self.tune_in_idx.keys())
+        except:
+            shuffled_tune_names = []
+        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
+        song_dict = {}
+        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
+            if song not in song_dict:
+                song_dict[song] = []
+            song_dict[song].append(orig_song)
+        unique_song_names = list(song_dict.keys())
+        random.seed(seed)
+        random.shuffle(unique_song_names)
+        num_train = int(len(unique_song_names)*ratio)
+        num_valid = int(len(unique_song_names)*(1-ratio)/2)
+        train_names = []
+        valid_names = []
+        test_names = []
+        for song_name in unique_song_names[:num_train]:
+            train_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train:num_train+num_valid]:
+            valid_names.extend(song_dict[song_name])
+        for song_name in unique_song_names[num_train+num_valid:]:
+            test_names.extend(song_dict[song_name])
+        return train_names, valid_names, test_names, shuffled_tune_names
+
 class gigamidi(SymbolicMusicDataset):
     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                  for_evaluation: bool = False):
         super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                          for_evaluation=for_evaluation)
@@ -1081,11 +1319,7 @@ class LakhALLFined(SymbolicMusicDataset):
         '''
         # filter out none in tune_in_idx
         print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
-        try:
-            self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
-        except:
-            print("Error filtering None values in tune_in_idx, skipping filtering")
-            return [], [], [], []
+        self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
         print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
         shuffled_tune_names = list(self.tune_in_idx.keys())
         song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
@@ -1579,11 +1813,7 @@ class FinetuneDataset(SymbolicMusicDataset):
         '''
         # filter out none in tune_in_idx
         print("length of tune_in_idx before filtering:", len(self.tune_in_idx))
-        try:
-            self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
-        except:
-            print("Error filtering None values in tune_in_idx, skipping filtering")
-            return [], [], [], []
+        self.tune_in_idx = {k: v for k, v in self.tune_in_idx.items() if v is not None}
         print("length of tune_in_idx after filtering:", len(self.tune_in_idx))
         shuffled_tune_names = list(self.tune_in_idx.keys())
         song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
diff --git a/Amadeus/transformer_utils.py b/Amadeus/transformer_utils.py
index 3982341..d30e1b6 100644
--- a/Amadeus/transformer_utils.py
+++ b/Amadeus/transformer_utils.py
@@ -389,85 +389,6 @@ class XtransformerCrossAttendDecoder(nn.Module):
         else:
             return self.transformer_decoder(seq, context=context)
 
-class XtransformerLargeCrossAttendDecoder(nn.Module):
-    def __init__(
-        self,
-        dim:int,
-        depth:int,
-        heads:int,
-        dropout:float
-    ):
-        super().__init__()
-        self._make_decoder_layer(dim, depth, heads, dropout)
-        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
-        # frozen text encoder
-        for param in self.text_encoder.parameters():
-            param.requires_grad = False
-
-    def _make_decoder_layer(self, dim, depth, heads, dropout):
-        self.transformer_decoder = Decoder(
-            dim = dim,
-            depth = depth,
-            heads = heads,
-            attn_dropout = dropout,
-            ff_dropout = dropout,
-            attn_flash = True,
-            cross_attend = True,
-            only_cross = False)
-        # add final dropout
-        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
-        self._apply_xavier_init()
-        print('Adding dropout after feedforward layer in x-transformer')
-        self._add_dropout_after_ff(dropout)
-        print('Adding dropout after attention layer in x-transformer')
-        self._add_dropout_after_attn(dropout)
-
-    def _add_dropout_after_attn(self, dropout):
-        for layer in self.transformer_decoder.layers:
-            if 'Attention' in str(type(layer[1])):
-                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
-                    layer[1].to_out.append(nn.Dropout(dropout))
-                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
-                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
-                else:
-                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')
-
-    def _add_dropout_after_ff(self, dropout):
-        for layer in self.transformer_decoder.layers:
-            if 'FeedForward' in str(type(layer[1])):
-                layer[1].ff.append(nn.Dropout(dropout))
-
-    def _apply_xavier_init(self):
-        for name, param in self.transformer_decoder.named_parameters():
-            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
-                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
-
-    def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
-        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
-        if context_embedding is None:
-            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
-            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
-            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
-            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
-            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
-
-            context = self.text_encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask
-            ).last_hidden_state
-        else:
-            context = context_embedding
-
-        if cache is not None: # implementing run_one_step in inference
-            if cache.hiddens is None: cache = None
-            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
-            return hidden_vec, intermediates
-        else:
-            if train:
-                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
-                return hidden_vec, intermediates
-            else:
-                return self.transformer_decoder(seq, context=context)
 
 class NewCrossAttendDecoder(nn.Module):
     def __init__(
@@ -638,6 +559,75 @@ class NewCrossAttendwithRoPEDecoder(nn.Module):
         else:
             return self.transformer_decoder(seq, context=context)
 
+class RoPEDecoder(nn.Module):
+    def __init__(
+        self,
+        dim:int,
+        depth:int,
+        heads:int,
+        dropout:float
+    ):
+        super().__init__()
+        self._make_decoder_layer(dim, depth, heads, dropout)
+        # self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
+        # frozen text encoder
+
+    def _make_decoder_layer(self, dim, depth, heads, dropout):
+        self.transformer_decoder = Decoder(
+            dim = dim,
+            depth = depth,
+            heads = heads,
+            attn_dropout = dropout,
+            ff_dropout = dropout,
+            attn_flash = True,
+            # cross_attend = True,
+            only_cross = False,
+            use_rmsnorm=True,
+            rotary_pos_emb = True,
+            ff_swish = True, # set this to True
+            ff_glu = True,   # set to true to use for all feedforwards
+        )
+        # add final dropout
+        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
+        self._apply_xavier_init()
+        print('Adding dropout after feedforward layer in x-transformer')
+        self._add_dropout_after_ff(dropout)
+        print('Adding dropout after attention layer in x-transformer')
+        self._add_dropout_after_attn(dropout)
+
+    def _add_dropout_after_attn(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'Attention' in str(type(layer[1])):
+                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
+                    layer[1].to_out.append(nn.Dropout(dropout))
+                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
+                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
+                else:
+                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')
+
+    def _add_dropout_after_ff(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'FeedForward' in str(type(layer[1])):
+                layer[1].ff.append(nn.Dropout(dropout))
+
+    def _apply_xavier_init(self):
+        for name, param in self.transformer_decoder.named_parameters():
+            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
+                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
+
+    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
+        if cache is not None: # implementing run_one_step in inference
+            if cache.hiddens is None: cache = None
+            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
+            return hidden_vec, intermediates
+        else:
+            if train:
+                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
+                return hidden_vec, intermediates
+            else:
+                return self.transformer_decoder(seq)
+
+
 class XtransformerPrefixDecoder(nn.Module):
     def __init__(
         self,
@@ -711,7 +701,80 @@ class XtransformerPrefixDecoder(nn.Module):
             return hidden_vec, intermediates
         else:
             return self.transformer_decoder(seq)
+
+class XtransformerNewPretrainingDecoder(nn.Module):
+    def __init__(
+        self,
+        dim:int,
+        depth:int,
+        heads:int,
+        dropout:float
+    ):
+        super().__init__()
+        self._make_decoder_layer(dim, depth, heads, dropout)
+        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
+        # frozen text encoder
+        for param in self.text_encoder.parameters():
+            param.requires_grad = False
+
+    def _make_decoder_layer(self, dim, depth, heads, dropout):
+        self.transformer_decoder = Decoder(
+            dim = dim,
+            depth = depth,
+            heads = heads,
+            attn_dropout = dropout,
+            ff_dropout = dropout,
+            attn_flash = True,
+            use_rmsnorm=True,
+            rotary_pos_emb = True,
+            ff_swish = True, # set this to True
+            ff_glu = True,   # set to true to use for all feedforwards
+            # shift_tokens = 1,
+            # attn_qk_norm = True,
+            # attn_qk_norm_dim_scale = True
+        )
+        # add final dropout
+        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
+        self._apply_xavier_init()
+        print('Adding dropout after feedforward layer in x-transformer')
+        self._add_dropout_after_ff(dropout)
+        print('Adding dropout after attention layer in x-transformer')
+        self._add_dropout_after_attn(dropout)
+
+    def _add_dropout_after_attn(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'Attention' in str(type(layer[1])):
+                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
+                    layer[1].to_out.append(nn.Dropout(dropout))
+                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
+                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
+                else:
+                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')
+
+    def _add_dropout_after_ff(self, dropout):
+        for layer in self.transformer_decoder.layers:
+            if 'FeedForward' in str(type(layer[1])):
+                layer[1].ff.append(nn.Dropout(dropout))
+
+    def _apply_xavier_init(self):
+        for name, param in self.transformer_decoder.named_parameters():
+            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
+                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
+
+    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
+
+        if cache is not None: # implementing run_one_step in inference
+            if cache.hiddens is None: cache = None
+            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
+            return hidden_vec, intermediates
+        else:
+            if train:
+                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
+                return hidden_vec, intermediates
+            else:
+                return self.transformer_decoder(seq)
+
+
 class XtransformerPretrainingDecoder(nn.Module):
     def __init__(
         self,
@@ -827,92 +890,6 @@ class XtransformerFinetuningDecoder(nn.Module):
             if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                 torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
 
-    def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
-        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
-        if context_embedding is None:
-            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
-            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
-            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
-            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
-            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'
-
-            context = self.text_encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-            ).last_hidden_state
-        else:
-            context = context_embedding
-
-        # concatenate context with seq
-        seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
-        if cache is not None: # implementing run_one_step in inference
-            if cache.hiddens is None: cache = None
-            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
-            # cut to only return the seq part
-            return hidden_vec, intermediates
-        else:
-            if train:
-                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
-                # cut to only return the seq part
-                hidden_vec = hidden_vec[:, context.shape[1]:, :]
-                return hidden_vec, intermediates
-            else:
-                # cut to only return the seq part
-                hidden_vec = self.transformer_decoder(seq)
-                hidden_vec = hidden_vec[:, context.shape[1]:, :]
-                return hidden_vec
-
-class XtransformerLargeFinetuningDecoder(nn.Module):
-    def __init__(
-        self,
-        dim:int,
-        depth:int,
-        heads:int,
-        dropout:float
-    ):
-        super().__init__()
-        self._make_decoder_layer(dim, depth, heads, dropout)
-        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
-        # frozen text encoder
-        for param in self.text_encoder.parameters():
-            param.requires_grad = False
-
-    def _make_decoder_layer(self, dim, depth, heads, dropout):
-        self.transformer_decoder = Decoder(
-            dim = dim,
-            depth = depth,
-            heads = heads,
-            attn_dropout = dropout,
-            ff_dropout = dropout,
-            attn_flash = True)
-        # add final dropout
-        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
-        self._apply_xavier_init()
-        print('Adding dropout after feedforward layer in x-transformer')
-        self._add_dropout_after_ff(dropout)
-        print('Adding dropout after attention layer in x-transformer')
-        self._add_dropout_after_attn(dropout)
-
-    def _add_dropout_after_attn(self, dropout):
-        for layer in self.transformer_decoder.layers:
-            if 'Attention' in str(type(layer[1])):
-                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
-                    layer[1].to_out.append(nn.Dropout(dropout))
-                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
-                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
-                else:
-                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')
-
-    def _add_dropout_after_ff(self, dropout):
-        for layer in self.transformer_decoder.layers:
-            if 'FeedForward' in str(type(layer[1])):
-                layer[1].ff.append(nn.Dropout(dropout))
-
-    def _apply_xavier_init(self):
-        for name, param in self.transformer_decoder.named_parameters():
-            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
-                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)
-
     def forward(self, seq, cache=None,train=False,context=None,context_embedding=None):
         assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
         if context_embedding is None: