1021 add flexable attr control
This commit is contained in:
@ -111,6 +111,23 @@ class MultiEmbedding(nn.Module):
|
||||
def get_emb_by_key(self, key, token):
|
||||
layer_idx = self.feature_list.index(key)
|
||||
return self.layers[layer_idx](token)
|
||||
|
||||
def get_token_by_emb(self, key, token_emb):
|
||||
'''
|
||||
token_emb: B x emb_size
|
||||
'''
|
||||
layer_idx = self.feature_list.index(key)
|
||||
embedding_layer = self.layers[layer_idx] # nn.Embedding
|
||||
# compute cosine similarity between token_emb and embedding weights
|
||||
emb_weights = embedding_layer.weight # vocab_size x emb_size
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
token_emb.unsqueeze(1), # B x 1 x emb_size
|
||||
emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
|
||||
dim=-1
|
||||
) # B x vocab_size
|
||||
# get the index of the most similar embedding
|
||||
token_idx = torch.argmax(cos_sim, dim=-1) # B
|
||||
return token_idx
|
||||
|
||||
class SummationEmbedder(MultiEmbedding):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user