add gitignore
@@ -389,85 +389,6 @@ class XtransformerCrossAttendDecoder(nn.Module):
        else:
            return self.transformer_decoder(seq, context=context)


class XtransformerLargeCrossAttendDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True,
            cross_attend = True,
            only_cross = False)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
        else:
            context = context_embedding

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None: cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True, context=context)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, context=context, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq, context=context)
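
Not part of the diff: a minimal usage sketch for XtransformerLargeCrossAttendDecoder as written above. It assumes the class and its module are importable, that seq is already embedded to width dim, and that dim is set to 1024 so the flan-t5-large hidden states can be cross-attended without an extra projection; the matching-width choice and all sizes below are my assumptions, not something this diff shows.

import torch
from transformers import AutoTokenizer

decoder = XtransformerLargeCrossAttendDecoder(dim=1024, depth=4, heads=8, dropout=0.1)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large')  # pairs with the frozen T5 encoder above

context = tokenizer(['a calm piano piece in C major'], return_tensors='pt')  # dict with input_ids / attention_mask
seq = torch.randn(1, 16, 1024)  # placeholder for already-embedded target tokens, B x T x dim

# training path: hidden states plus layer intermediates
hidden_vec, intermediates = decoder(seq, train=True, context=context)

# plain inference path: hidden states only
hidden_vec = decoder(seq, context=context)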

class NewCrossAttendDecoder(nn.Module):
    def __init__(
@@ -638,6 +559,75 @@ class NewCrossAttendwithRoPEDecoder(nn.Module):
        else:
            return self.transformer_decoder(seq, context=context)


class RoPEDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        # self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True,
            # cross_attend = True,
            only_cross = False,
            use_rmsnorm = True,
            rotary_pos_emb = True,
            ff_swish = True,  # set this to True
            ff_glu = True,  # set to true to use for all feedforwards
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None: cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)
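
Not part of the diff: a minimal usage sketch for RoPEDecoder as written above. It is a plain decoder-only stack (rotary position embeddings, RMSNorm, SwiGLU feedforwards), so no text context is passed; seq is assumed to already be embedded to width dim, and the sizes below are made up for illustration.

import torch

decoder = RoPEDecoder(dim=512, depth=6, heads=8, dropout=0.1)
seq = torch.randn(2, 32, 512)  # B x T x dim placeholder embeddings

out = decoder(seq)                                    # inference path: hidden states only, 2 x 32 x 512
hidden_vec, intermediates = decoder(seq, train=True)  # training path: also returns layer intermediates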

class XtransformerPrefixDecoder(nn.Module):
    def __init__(
        self,
@@ -711,7 +701,80 @@ class XtransformerPrefixDecoder(nn.Module):
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)


class XtransformerNewPretrainingDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True,
            use_rmsnorm = True,
            rotary_pos_emb = True,
            ff_swish = True,  # set this to True
            ff_glu = True,  # set to true to use for all feedforwards
            # shift_tokens = 1,
            # attn_qk_norm = True,
            # attn_qk_norm_dim_scale = True
        )
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):

        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None: cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                return hidden_vec, intermediates
            else:
                return self.transformer_decoder(seq)

class XtransformerPretrainingDecoder(nn.Module):
    def __init__(
        self,
@@ -827,92 +890,6 @@ class XtransformerFinetuningDecoder(nn.Module):
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None:
            input_ids = context['input_ids'].squeeze(1) if context['input_ids'].ndim == 3 else context['input_ids']
            attention_mask = context['attention_mask'].squeeze(1) if context['attention_mask'].ndim == 3 else context['attention_mask']
            assert input_ids is not None, 'input_ids should be provided for prefix decoder'
            assert attention_mask is not None, 'attention_mask should be provided for prefix decoder'
            assert input_ids.device == self.text_encoder.device, 'input_ids should be on the same device as text_encoder'

            context = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        else:
            context = context_embedding

        # concatenate context with seq
        seq = torch.cat([context, seq], dim=1)  # B x (T+context_length) x emb_size
        if cache is not None:  # implementing run_one_step in inference
            if cache.hiddens is None: cache = None
            hidden_vec, intermediates = self.transformer_decoder(seq, cache=cache, return_hiddens=True)
            # cut to only return the seq part
            return hidden_vec, intermediates
        else:
            if train:
                hidden_vec, intermediates = self.transformer_decoder(seq, return_hiddens=True)
                # cut to only return the seq part
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec, intermediates
            else:
                # cut to only return the seq part
                hidden_vec = self.transformer_decoder(seq)
                hidden_vec = hidden_vec[:, context.shape[1]:, :]
                return hidden_vec
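
Not part of the diff: a shape-only sketch of the prefix trick used in the forward above, with made-up sizes. The frozen encoder's hidden states are prepended to the embedded target sequence along the time axis, the decoder runs over the joint sequence, and the context positions are sliced off again so callers only see hidden states for seq; the assumption that the model width D equals the encoder's hidden size is mine, though the concatenation requires it.

import torch

B, T, L, D = 2, 16, 10, 768                       # batch, target length, context length, model width (made up)
context = torch.randn(B, L, D)                    # stand-in for text_encoder(...).last_hidden_state
seq = torch.randn(B, T, D)                        # stand-in for the embedded target tokens

joint = torch.cat([context, seq], dim=1)          # B x (L + T) x D, as in the forward above
hidden_vec = joint                                # stand-in for self.transformer_decoder(joint)
hidden_vec = hidden_vec[:, context.shape[1]:, :]  # drop the context positions
assert hidden_vec.shape == (B, T, D)              # only the seq part is returned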

class XtransformerLargeFinetuningDecoder(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        dropout: float
    ):
        super().__init__()
        self._make_decoder_layer(dim, depth, heads, dropout)
        self.text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-large')
        # frozen text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _make_decoder_layer(self, dim, depth, heads, dropout):
        self.transformer_decoder = Decoder(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = dropout,
            ff_dropout = dropout,
            attn_flash = True)
        # add final dropout
        print('Applying Xavier Uniform Init to x-transformer following torch.Transformer')
        self._apply_xavier_init()
        print('Adding dropout after feedforward layer in x-transformer')
        self._add_dropout_after_ff(dropout)
        print('Adding dropout after attention layer in x-transformer')
        self._add_dropout_after_attn(dropout)

    def _add_dropout_after_attn(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'Attention' in str(type(layer[1])):
                if isinstance(layer[1].to_out, nn.Sequential):  # if GLU
                    layer[1].to_out.append(nn.Dropout(dropout))
                elif isinstance(layer[1].to_out, nn.Linear):  # if simple linear
                    layer[1].to_out = nn.Sequential(layer[1].to_out, nn.Dropout(dropout))
                else:
                    raise ValueError('to_out should be either nn.Sequential or nn.Linear')

    def _add_dropout_after_ff(self, dropout):
        for layer in self.transformer_decoder.layers:
            if 'FeedForward' in str(type(layer[1])):
                layer[1].ff.append(nn.Dropout(dropout))

    def _apply_xavier_init(self):
        for name, param in self.transformer_decoder.named_parameters():
            if 'to_q' in name or 'to_k' in name or 'to_v' in name:
                torch.nn.init.xavier_uniform_(param, gain=0.5**0.5)

    def forward(self, seq, cache=None, train=False, context=None, context_embedding=None):
        assert context is not None or context_embedding is not None, 'context or context_embedding should be provided for prefix decoder'
        if context_embedding is None: