from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout):
        super().__init__()
        self.out_size = out_size
        self.layer = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )

    def forward(self, x):
        return self.layer(x)


class extendedMLP(nn.Module):
    '''
    MLP with `num_layers` nn.Linear layers; Dropout and ReLU are applied between consecutive Linear layers.
    '''
    def __init__(self, in_size, out_size, num_layers, hidden_size, dropout):
        super().__init__()
        self.input_size = in_size  # read by ResidualLayerNormModule to size its LayerNorm

        self.layers = nn.ModuleList()
        if num_layers == 1:
            # Only one layer
            self.layers.append(nn.Linear(in_size, out_size))
            return
        elif num_layers > 1:
            # First layer
            self.layers.append(nn.Linear(in_size, hidden_size))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(nn.ReLU())
            # Intermediate layers
            if num_layers > 2:
                for _ in range(num_layers - 2):  # -2 because the first and last layers are added manually
                    self.layers.append(nn.Linear(hidden_size, hidden_size))
                    self.layers.append(nn.Dropout(dropout))
                    self.layers.append(nn.ReLU())
            # Last layer
            self.layers.append(nn.Linear(hidden_size, out_size))
        else:
            raise ValueError("num_layers should be a positive integer")

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

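# --- usage sketch (illustration only, not part of the original module; the helper name is invented) ---
# Shows how `num_layers` counts the nn.Linear layers: num_layers=3 yields
# Linear -> Dropout -> ReLU -> Linear -> Dropout -> ReLU -> Linear.
def _demo_extended_mlp():
    mlp = extendedMLP(in_size=16, out_size=8, num_layers=3, hidden_size=32, dropout=0.1)
    x = torch.randn(2, 10, 16)   # B x T x in_size
    y = mlp(x)                   # B x T x out_size
    assert y.shape == (2, 10, 8)
    return y
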
class multiMLP(nn.Module):
    def __init__(self, in_size, out_size, hidden_size, dropout, pred_order):
        super().__init__()
        self.out_size = out_size
        self.pred_order = pred_order  # needed by forward() to map a token type to its MLP
        self.layer = nn.ModuleList([MLP(in_size, out_size, hidden_size, dropout) for _ in pred_order])

    def forward(self, x, choice):
        '''
        x: B x T x d_model
        choice: token type from self.pred_order (str or list of str)
        '''
        if isinstance(choice, str):
            idx = self.pred_order.index(choice)
            return self.layer[idx](x)
        else:
            raise ValueError("multiMLP doesn't support parallel prediction")

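# --- usage sketch (illustration only, not part of the original module; the helper name and the
# token-type names are invented for the example) ---
# multiMLP keeps one MLP per token type in `pred_order` and routes by the `choice` string.
def _demo_multi_mlp():
    heads = multiMLP(in_size=16, out_size=8, hidden_size=32, dropout=0.1, pred_order=['pitch', 'duration'])
    x = torch.randn(2, 10, 16)        # B x T x d_model
    out = heads(x, choice='pitch')    # uses the MLP registered for 'pitch'
    assert out.shape == (2, 10, 8)
    return out
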
class ResidualLayerNormModule(nn.Module):
    '''
    Residual Add & Norm wrapper: returns LayerNorm(submodule output + residual input).
    '''
    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule
        if isinstance(submodule, nn.MultiheadAttention):
            self.layer_norm = nn.LayerNorm(self.submodule.embed_dim)
        else:
            self.layer_norm = nn.LayerNorm(self.submodule.input_size)

    def forward_attention(self, q, k, v, attn_mask, type):
        # `type` is only a label passed by callers ('self', 'cross', ...); it is not used here
        attn_output, _ = self.submodule(q, k, v, attn_mask=attn_mask, need_weights=False, average_attn_weights=False)
        return self.layer_norm(attn_output + q)

    def forward_mlp(self, x):
        return self.layer_norm(self.submodule(x) + x)

class MultiProj_hidden2logit(nn.Module):
    '''
    One Linear head per feature, projecting a hidden vector to that feature's vocabulary logits.
    '''
    def __init__(self, dim, vocab_sizes):
        super().__init__()
        self.layers = nn.ModuleDict({
            f"layer_{key}": nn.Linear(dim, size) for key, size in vocab_sizes.items()
        })

    def forward(self, hidden_vec, feature):
        logit = self.layers[f"layer_{feature}"](hidden_vec)
        return logit

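# --- usage sketch (illustration only, not part of the original module; the helper name and the
# feature names/sizes are invented for the example) ---
# `vocab_sizes` maps feature name -> vocabulary size; forward() selects the head for one feature.
def _demo_multiproj_hidden2logit():
    proj = MultiProj_hidden2logit(dim=16, vocab_sizes={'pitch': 50, 'duration': 20})
    hidden = torch.randn(2, 10, 16)             # B x T x dim
    logits = proj(hidden, feature='pitch')      # B x T x vocab_sizes['pitch']
    assert logits.shape == (2, 10, 50)
    return logits
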
class MultiProj_catvec2hidden(nn.Module):
    '''
    This class is used in SQstyleEachEmbStrategy.
    par_pred_keys: list of independent features (these tokens are predicted in parallel)
    seq_pred_keys: list of sequential features (these tokens are predicted sequentially)
    '''
    def __init__(self, config, par_pred_keys, seq_pred_keys):
        super().__init__()
        net_param = config.nn_params
        self.d_model = net_param.model.d_model
        independent_emb_size = sum(net_param.emb[key] for key in par_pred_keys)
        self.layers = nn.ModuleDict({
            'layer_independent': nn.Linear(self.d_model + independent_emb_size, self.d_model),
            **{f"layer_{key}": nn.Linear(self.d_model + net_param.emb[key], self.d_model) for key in seq_pred_keys}
        })
        self.par_pred_keys = par_pred_keys
        self.seq_pred_keys = seq_pred_keys
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def forward(self, x, choice):
        '''
        x: B x T x (d_model + emb_size)
        choice: key type (str or list of str)
        '''
        if isinstance(choice, str):  # single sequential key
            assert choice in self.seq_pred_keys
            output = self.layers[f"layer_{choice}"](x)
            return self.relu(self.dropout(output))
        else:  # list of keys, predicted in parallel
            assert choice == self.par_pred_keys  # the order of choice must match the order of self.par_pred_keys
            output = self.layers['layer_independent'](x)
            return self.relu(self.dropout(output))

def mask_tensor(tensor, mask_rate=0.15):
    # Get the size of the tensor
    batch_size, seq_len, dim = tensor.size()
    # Calculate the total number of elements and the number to mask
    total_elements = batch_size * seq_len
    num_to_mask = int(total_elements * mask_rate)
    # Create a 1D binary mask where 1 indicates that element will be masked.
    # Start by creating a tensor of zeros with length equal to the total number of elements.
    mask = torch.zeros(total_elements).to(tensor.device)
    # Set `num_to_mask` random indices to 1 (masking)
    indices_to_mask = torch.randperm(total_elements)[:num_to_mask]
    mask[indices_to_mask] = 1
    # Reshape the mask to match the original tensor's shape
    mask = mask.reshape(batch_size, seq_len)
    mask = mask.unsqueeze(2)  # B x T x 1
    masked_tensor = tensor * (mask == 0).float()  # B x T x d_model
    return masked_tensor

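# --- usage sketch (illustration only, not part of the original module; the helper name is invented) ---
# mask_tensor zeroes out exactly int(B * T * mask_rate) of the (batch, time) positions,
# chosen uniformly at random across the whole batch.
def _demo_mask_tensor():
    x = torch.ones(4, 25, 8)                          # B x T x d_model
    masked = mask_tensor(x, mask_rate=0.2)
    # 20% of the 4 * 25 = 100 positions should be zero vectors
    num_zeroed = (masked.abs().sum(dim=-1) == 0).sum().item()
    assert num_zeroed == int(4 * 25 * 0.2)
    return masked
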
def generate_causality_mask_on_window(size, window_size):
    # Boolean mask: query position i may attend to keys j < i + window_size;
    # positions further ahead are masked (True).
    mask = torch.zeros((size, size))
    for i in range(size):
        mask[i, i+window_size:] = 1
    return mask.bool()

# Generate a boolean mask: a value of 1 (True) means the position is masked.
# Accounts for the BOS token through `mask_margin`.
def generate_CA_mask(tgt_len, memory_len, mask_margin=0):
    mask = torch.triu(torch.ones((tgt_len, memory_len)), diagonal=mask_margin+1)
    return mask.bool()

# Generate a boolean mask: a value of 1 (True) means the position is masked.
def generate_SA_mask(tgt_len):
    mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1)
    return mask.bool()

def generate_none_causality_mask(tgt_len, memory_len):
    # All-False mask: nothing is masked, so attention is unrestricted.
    mask = torch.zeros((tgt_len, memory_len))
    return mask.bool()

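# --- usage sketch (illustration only, not part of the original module; the helper name and all
# sizes are invented for the example) ---
# The boolean masks above follow the nn.MultiheadAttention convention: True = "do not attend".
# generate_SA_mask gives standard causal self-attention; generate_CA_mask does the same for
# cross-attention, shifted by `mask_margin` (e.g. to account for a BOS token in the memory).
def _demo_attention_masks():
    sa_mask = generate_SA_mask(5)                      # 5 x 5, upper triangle masked
    ca_mask = generate_CA_mask(5, 7, mask_margin=1)    # 5 x 7, each query sees one extra memory step
    self_attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
    cross_attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
    q = torch.randn(2, 5, 16)
    mem = torch.randn(2, 7, 16)
    out_sa, _ = self_attn(q, q, q, attn_mask=sa_mask)        # causal self-attention
    out_ca, _ = cross_attn(q, mem, mem, attn_mask=ca_mask)   # masked cross-attention
    return out_sa, out_ca
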
class DecoderLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict

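# --- usage sketch (illustration only, not part of the original module; the helper name and all
# sizes are invented for the example) ---
# Because forward() takes and returns a single dict, several DecoderLayers can be chained with
# nn.Sequential; the memory and its mask are passed through unchanged.
def _demo_decoder_stack():
    decoder = nn.Sequential(*[DecoderLayer(dim=16, num_heads=2, dropout=0.1) for _ in range(2)])
    tgt = torch.randn(2, 5, 16)                       # B x T x d_model
    memory = torch.randn(2, 7, 16)                    # B x S x d_model
    mask = generate_CA_mask(5, 7)                     # True = masked
    out = decoder({'input_seq': tgt, 'memory': memory, 'memory_mask': mask})
    return out['input_seq']                           # B x T x d_model
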
class TransformerLayer(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.self_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        # self attention
        attn_output = self.self_attn_block.forward_attention(input_dict['input_seq'], input_dict['input_seq'], input_dict['input_seq'], input_dict['memory_mask'], type='self')

        input_dict['input_seq'] = attn_output
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], input_dict['memory_mask'], type='cross')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory'], 'memory_mask': input_dict['memory_mask']}
        return output_dict

class FeatureEnricher(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.cross_attn_block = ResidualLayerNormModule(nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True))
        self.residual_FF = ResidualLayerNormModule(extendedMLP(in_size=dim, out_size=dim, num_layers=2, hidden_size=2048, dropout=dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory}
        '''
        # cross attention
        attn_output = self.cross_attn_block.forward_attention(input_dict['input_seq'], input_dict['memory'], input_dict['memory'], None, type='feature_enrichment')
        attn_output = self.residual_FF.forward_mlp(attn_output)
        attn_output = self.dropout(attn_output)
        output_dict = {'input_seq': attn_output, 'memory': input_dict['memory']}
        return output_dict