1013 update
@@ -1358,6 +1358,7 @@ class Attention(Module):
        dim_latent_kv = None,
        latent_rope_subheads = None,
        onnxable = False,
        use_gated_attention = False, # https://arxiv.org/abs/2505.06708
        attend_sdp_kwargs: dict = dict(
            enable_flash = True,
            enable_math = True,
@@ -1387,6 +1388,7 @@ class Attention(Module):
        k_dim = dim_head * kv_heads
        v_dim = value_dim_head * kv_heads
        out_dim = value_dim_head * heads
        gated_dim = out_dim

        # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
        # for eventually supporting multi-latent attention (MLA)
@@ -1447,7 +1449,8 @@ class Attention(Module):

        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, out_dim)
            # self.to_v_gate = nn.Linear(dim, out_dim)
            self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
            self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 10)

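The new `use_gated_attention` flag follows "Gated Attention for Large Language Models" (https://arxiv.org/abs/2505.06708): a sigmoid gate computed from the attention input multiplies the attention output elementwise, and the bias init of 10 makes the gate start near 1 (pass-through). A minimal self-contained sketch of the idea, using hypothetical names rather than the x-transformers internals patched above:

```python
import torch
import torch.nn.functional as F
from torch import nn

class GatedAttentionSketch(nn.Module):
    # Hedged sketch of output gating (arXiv:2505.06708); hypothetical names,
    # not the x-transformers code being patched in this commit.
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_gate = nn.Linear(dim, inner)
        self.to_out = nn.Linear(inner, dim)
        nn.init.zeros_(self.to_gate.weight)       # gate starts constant,
        nn.init.constant_(self.to_gate.bias, 10)  # sigmoid(10) ~ 1: pass-through

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)   # standard SDPA
        out = out.transpose(1, 2).reshape(b, n, -1)
        out = out * self.to_gate(x).sigmoid()           # elementwise output gate
        return self.to_out(out)
```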
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F


def top_p_sampling(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -84,7 +84,7 @@ def sample_with_prob(logits, sampling_method, threshold, temperature):
    # temporarily apply the sampling method to logits
    logits = logits / temperature
    # logits = add_gumbel_noise(logits, temperature)

    if sampling_method == "top_p":
        modified_logits = top_p_sampling(logits, thres=threshold)
    elif sampling_method == "typical":

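The hunk truncates `top_p_sampling` after the cumulative-probability line. The common completion of nucleus sampling (a sketch under that assumption, not necessarily this repo's exact tail) masks everything outside the nucleus and scatters the logits back to vocabulary order:

```python
import torch
import torch.nn.functional as F

def top_p_sampling_sketch(logits, thres=0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # drop tokens once cumulative probability exceeds thres,
    # shifted right so the top token always survives
    remove = cum_probs > thres
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits[remove] = float('-inf')
    # undo the sort so masked logits line up with the vocabulary again
    return sorted_logits.scatter(-1, sorted_indices, sorted_logits)
```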
@@ -530,6 +530,62 @@ class Melody(SymbolicMusicDataset):
            test_names.extend(song_dict[song_name])
        return train_names, valid_names, test_names, shuffled_tune_names

class msmidi(SymbolicMusicDataset):
    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Irregular tunes are removed from the dataset for better generation quality.
        These include tunes that are not quantized properly; mostly they are expressive performance data.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            irregular_tunes = json.load(f)
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            if tune_in_idx_file.stem in irregular_tunes:
                continue
            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
            file_name_list.append(tune_in_idx_file.stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
        '''
        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
        song_dict = {}
        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
            if song not in song_dict:
                song_dict[song] = []
            song_dict[song].append(orig_song)
        unique_song_names = list(song_dict.keys())
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        train_names = []
        valid_names = []
        test_names = []
        for song_name in unique_song_names[:num_train]:
            train_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train:num_train+num_valid]:
            valid_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train+num_valid:]:
            test_names.extend(song_dict[song_name])
        return train_names, valid_names, test_names, shuffled_tune_names

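`_get_split_list_from_tune_in_idx` keeps every version of a song in the same split by stripping the trailing `.<digits>` version suffix before shuffling. A quick illustration of the grouping step (hypothetical song names):

```python
import re

names = ["AuldLangSyne", "AuldLangSyne.1", "AuldLangSyne.2", "Greensleeves"]
song_dict = {}
for name in names:
    base = re.sub(r"\.\d+$", "", name)  # drop the ".<digits>" version suffix
    song_dict.setdefault(base, []).append(name)
print(song_dict)
# {'AuldLangSyne': ['AuldLangSyne', 'AuldLangSyne.1', 'AuldLangSyne.2'],
#  'Greensleeves': ['Greensleeves']}
```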
class IrishMan(SymbolicMusicDataset):
    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@@ -648,62 +704,62 @@ class ariamidi(SymbolicMusicDataset):
            test_names.extend(song_dict[song_name])
        return train_names, valid_names, test_names, shuffled_tune_names

class gigamidi(SymbolicMusicDataset):
    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
                 for_evaluation: bool = False):
        super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
                         for_evaluation=for_evaluation)
# class gigamidi(SymbolicMusicDataset):
#     def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
#                  for_evaluation: bool = False):
#         super().__init__(vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path,
#                          for_evaluation=for_evaluation)

    def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
        '''
        Irregular tunes are removed from the dataset for better generation quality.
        These include tunes that are not quantized properly; mostly they are expressive performance data.
        '''
        print("preprocessed tune_in_idx data is being loaded")
        tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
        if self.debug:
            tune_in_idx_list = tune_in_idx_list[:5000]
        tune_in_idx_dict = OrderedDict()
        len_tunes = OrderedDict()
        file_name_list = []
        with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
            irregular_tunes = json.load(f)
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            if tune_in_idx_file.stem in irregular_tunes:
                continue
            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
            file_name_list.append(tune_in_idx_file.stem)
        print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
        return tune_in_idx_dict, len_tunes, file_name_list
#     def _load_tune_in_idx(self) -> Tuple[Dict[str, np.ndarray], Dict[str, int], List[str]]:
#         '''
#         Irregular tunes are removed from the dataset for better generation quality.
#         These include tunes that are not quantized properly; mostly they are expressive performance data.
#         '''
#         print("preprocessed tune_in_idx data is being loaded")
#         tune_in_idx_list = sorted(list(Path(f"dataset/represented_data/tuneidx/tuneidx_{self.__class__.__name__}/{self.encoding_scheme}{self.num_features}").rglob("*.npz")))
#         if self.debug:
#             tune_in_idx_list = tune_in_idx_list[:5000]
#         tune_in_idx_dict = OrderedDict()
#         len_tunes = OrderedDict()
#         file_name_list = []
#         with open("metadata/LakhClean_irregular_tunes.json", "r") as f:
#             irregular_tunes = json.load(f)
#         for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
#             if tune_in_idx_file.stem in irregular_tunes:
#                 continue
#             tune_in_idx = np.load(tune_in_idx_file)['arr_0']
#             tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
#             len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)
#             file_name_list.append(tune_in_idx_file.stem)
#         print(f"number of loaded tunes: {len(tune_in_idx_dict)}")
#         return tune_in_idx_dict, len_tunes, file_name_list

    def _get_split_list_from_tune_in_idx(self, ratio, seed):
        '''
        As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
        '''
        shuffled_tune_names = list(self.tune_in_idx.keys())
        song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
        song_dict = {}
        for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
            if song not in song_dict:
                song_dict[song] = []
            song_dict[song].append(orig_song)
        unique_song_names = list(song_dict.keys())
        random.seed(seed)
        random.shuffle(unique_song_names)
        num_train = int(len(unique_song_names)*ratio)
        num_valid = int(len(unique_song_names)*(1-ratio)/2)
        train_names = []
        valid_names = []
        test_names = []
        for song_name in unique_song_names[:num_train]:
            train_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train:num_train+num_valid]:
            valid_names.extend(song_dict[song_name])
        for song_name in unique_song_names[num_train+num_valid:]:
            test_names.extend(song_dict[song_name])
        return train_names, valid_names, test_names, shuffled_tune_names
#     def _get_split_list_from_tune_in_idx(self, ratio, seed):
#         '''
#         As the Lakh dataset contains multiple versions of the same song, we split the dataset based on the song name.
#         '''
#         shuffled_tune_names = list(self.tune_in_idx.keys())
#         song_names_without_version = [re.sub(r"\.\d+$", "", song) for song in shuffled_tune_names]
#         song_dict = {}
#         for song, orig_song in zip(song_names_without_version, shuffled_tune_names):
#             if song not in song_dict:
#                 song_dict[song] = []
#             song_dict[song].append(orig_song)
#         unique_song_names = list(song_dict.keys())
#         random.seed(seed)
#         random.shuffle(unique_song_names)
#         num_train = int(len(unique_song_names)*ratio)
#         num_valid = int(len(unique_song_names)*(1-ratio)/2)
#         train_names = []
#         valid_names = []
#         test_names = []
#         for song_name in unique_song_names[:num_train]:
#             train_names.extend(song_dict[song_name])
#         for song_name in unique_song_names[num_train:num_train+num_valid]:
#             valid_names.extend(song_dict[song_name])
#         for song_name in unique_song_names[num_train+num_valid:]:
#             test_names.extend(song_dict[song_name])
#         return train_names, valid_names, test_names, shuffled_tune_names

class ariamidi(SymbolicMusicDataset):
    def __init__(self, vocab, encoding_scheme, num_features, debug, aug_type, input_length, first_pred_feature, caption_path=None,
@@ -788,6 +844,9 @@ class gigamidi(SymbolicMusicDataset):
        for tune_in_idx_file in tqdm(tune_in_idx_list, total=len(tune_in_idx_list)):
            if tune_in_idx_file.stem in irregular_tunes:
                continue
            if "drums-only" in tune_in_idx_file.stem:
                print(f"skipping {tune_in_idx_file.stem} as it is a drums-only file")
                continue
            tune_in_idx = np.load(tune_in_idx_file)['arr_0']
            tune_in_idx_dict[tune_in_idx_file.stem] = tune_in_idx
            len_tunes[tune_in_idx_file.stem] = len(tune_in_idx)

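Every loader reads tunes back with `np.load(...)['arr_0']`; `arr_0` is the key numpy assigns to an array saved positionally into an `.npz` archive. A round-trip illustration (hypothetical file name):

```python
import numpy as np

tokens = np.arange(12).reshape(3, 4)       # stand-in for a tokenized tune
np.savez("example_tune.npz", tokens)       # positional arg is stored under 'arr_0'
loaded = np.load("example_tune.npz")["arr_0"]
assert (loaded == tokens).all()
```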
@@ -11,7 +11,7 @@ License: MIT, see the LICENSE file

__all__ = ['FluidSynth']

DEFAULT_SOUND_FONT = '/data2/suhongju/research/music-generation/sound_file/CrisisGeneralMidi3.01.sf2'
DEFAULT_SOUND_FONT = 'Alex_GM.sf2'
DEFAULT_SAMPLE_RATE = 48000
DEFAULT_GAIN = 0.05
# DEFAULT_SOUND_FONT = "/data2/suhongju/research/music-generation/sound_file/Advent GM 7.sf2"

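This hunk swaps the default SoundFont of a vendored midi2audio `fluidsynth` module from an absolute path to a relative `Alex_GM.sf2`. Assuming the module keeps midi2audio's public API, typical usage looks like this (file names are placeholders, and `Alex_GM.sf2` must be resolvable from the working directory):

```python
from midi2audio import FluidSynth

fs = FluidSynth(sound_font='Alex_GM.sf2', sample_rate=48000)
fs.midi_to_audio('input.mid', 'output.wav')  # renders via the fluidsynth CLI
```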
@@ -2,7 +2,8 @@ defaults:
  # - nn_params: nb8_embSum_NMT
  # - nn_params: remi8
  # - nn_params: nb8_embSum_diff_t2m_150M_finetunning
  - nn_params: nb8_embSum_diff_t2m_150M_pretrainingv2
  # - nn_params: nb8_embSum_diff_t2m_600M_pretrainingv2
  - nn_params: nb8_embSum_diff_t2m_600M_finetunningv2
  # - nn_params: nb8_embSum_subPararell
  # - nn_params: nb8_embSum_diff_t2m_150M

@@ -14,7 +15,7 @@ defaults:
  # - nn_params: remi8_main12_head_16_dim512
  # - nn_params: nb5_embSum_diff_main12head16dim768_sub3

dataset: LakhClean # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
dataset: FinetuneDataset # Pop1k7, Pop909, SOD, LakhClean, PretrainingDataset, FinetuneDataset
captions_path: dataset/midicaps/train_set.json

# dataset: SymphonyNet_Dataset # Pop1k7, Pop909, SOD, LakhClean
@@ -30,20 +31,20 @@ tau: 0.5

train_params:
  device: cuda
  batch_size: 3
  batch_size: 5
  grad_clip: 1.0
  num_iter: 300000 # total number of iterations
  num_cycles_for_inference: 10 # number of cycles for inference, iterations_per_validation_cycle * num_cycles_for_inference
  num_cycles_for_model_checkpoint: 1 # number of cycles for model checkpoint, iterations_per_validation_cycle * num_cycles_for_model_checkpoint
  iterations_per_training_cycle: 10 # number of iterations for logging training loss
  iterations_per_validation_cycle: 5000 # number of iterations for validation process
  iterations_per_validation_cycle: 3000 # number of iterations for validation process
  input_length: 3072 # input sequence length
  # you can use focal loss; if it's not used, set focal_gamma to 0
  focal_alpha: 1
  focal_gamma: 0
  # learning rate scheduler: 'cosinelr', 'cosineannealingwarmuprestarts', 'not-using'; please check train_utils.py for more details
  scheduler: cosinelr
  initial_lr: 0.00005
  initial_lr: 0.0004
  decay_step_rate: 0.8 # the schedule reaches its lowest point at decay_step_rate * total_num_iter
  num_steps_per_cycle: 20000 # number of steps per cycle for 'cosineannealingwarmuprestarts'
  warmup_steps: 2000 # number of warmup steps

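A hedged reading of these knobs (the authoritative implementation is in the repo's `train_utils.py`): linear warmup for `warmup_steps`, then cosine decay that reaches its floor at `decay_step_rate * num_iter`. As a sketch:

```python
import math

def cosine_lr(step, initial_lr=4e-4, warmup_steps=2000,
              num_iter=300_000, decay_step_rate=0.8, min_lr=0.0):
    # Linear warmup, then cosine decay reaching min_lr at
    # decay_step_rate * num_iter and staying there afterwards.
    decay_end = int(num_iter * decay_step_rate)
    if step < warmup_steps:
        return initial_lr * step / max(1, warmup_steps)
    if step >= decay_end:
        return min_lr
    progress = (step - warmup_steps) / max(1, decay_end - warmup_steps)
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + math.cos(math.pi * progress))
```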
@@ -5,13 +5,13 @@ model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
model_dropout: 0.2
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 768
  num_layer: 20
  num_layer: 16
  num_head: 12
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added

@@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewFinetunningDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1024
  num_layer: 32
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
@@ -0,0 +1,19 @@
encoding_scheme: nb
num_features: 8
vocab_name: MusicTokenVocabNB
model_name: AmadeusModel
input_embedder_name: SummationEmbedder
main_decoder_name: XtransformerNewPretrainingDecoder
sub_decoder_name: DiffusionDecoder
model_dropout: 0
input_embedder:
  num_layer: 1
  num_head: 8
main_decoder:
  dim_model: 1024
  num_layer: 32
  num_head: 16
sub_decoder:
  decout_window_size: 1 # 1 means no previous decoding output added
  num_layer: 1
  feature_enricher_use: False
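Both new 19-line configs describe the larger variants (dim_model 1024, 32 layers, 16 heads), one per decoder. Note that `dim_model: 1024` differs from flan-t5-base's 768-dim hidden size, so the finetuning decoder shown below takes its `dim != 768` branch and adds a `text_project` linear layer to lift caption embeddings to the model width.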
@@ -729,9 +729,8 @@ class XtransformerNewPretrainingDecoder(nn.Module):
            rotary_pos_emb = True,
            ff_swish = True, # set this to True
            ff_glu = True, # set to True to use GLU for all feedforwards
            # shift_tokens = 1,
            # attn_qk_norm = True,
            # attn_qk_norm_dim_scale = True
            attn_gate_values = True,
            attn_qk_norm = True,
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
@@ -758,7 +757,7 @@ class XtransformerNewPretrainingDecoder(nn.Module):

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
            if 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
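Python's `and` binds tighter than `or`, so the new condition parses as `'to_q' in name or 'to_k' in name or ('to_v' in name and 'to_v_gate' not in name)`, which is the intended behavior: the `to_v_gate` parameters, which the Attention hunk above initializes to constants, are no longer clobbered by Xavier init. A quick check:

```python
for name in ['to_q.weight', 'to_k.weight', 'to_v.weight', 'to_v_gate.weight']:
    hit = 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name
    print(name, hit)
# to_q.weight True, to_k.weight True, to_v.weight True, to_v_gate.weight False
```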
@@ -906,6 +905,102 @@ class XtransformerFinetuningDecoder(nn.Module):
        else:
            context = context_embedding

        # concatenate context with seq
        seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
        if cache is not None: # implementing run_one_step in inference
            if cache.hiddens is None: cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            # cut to only return the seq part
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                # cut to only return the seq part
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec, intermediates
            else:
                # cut to only return the seq part
                hidden_vec = self.transformer_decoder(seq)
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec

class XtransformerNewFinetunningDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        if dim != 768:
            self.text_project = nn.Linear(768, dim) # assuming T5 base hidden size is 768

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True,
            use_rmsnorm = True,
            rotary_pos_emb = True,
            ff_swish = True, # set this to True
            ff_glu = True, # set to True to use GLU for all feedforwards
            attn_gate_values = True,
            attn_qk_norm = True,
        ) # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential): # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear): # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name and 'to_v_gate' not in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            context = context_embedding
        if hasattr(self, 'text_project'):
            context = self.text_project(context)

        # concatenate context with seq
        seq = torch.cat([context, seq], dim=1) # B x (T+context_length) x emb_size
        if cache is not None: # implementing run_one_step in inference
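`XtransformerNewFinetunningDecoder` prepends frozen flan-t5-base caption embeddings as a prefix and slices them off before returning hidden states. A hedged sketch of feeding it a caption batch (the tokenizer call is standard `transformers` usage; the decoder instantiation is commented out because the config above implies a large model):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
context = tokenizer(["a gentle piano waltz in 3/4"], return_tensors='pt', padding=True)

# decoder = XtransformerNewFinetunningDecoder(dim=1024, depth=32, heads=16, dropout=0.0)
# seq = torch.randn(1, 128, 1024)            # already-embedded music tokens
# hidden = decoder(seq, train=False, context=context)
# hidden.shape == (1, 128, 1024)             # caption prefix sliced off before return
```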