1021 add flexable attr control

This commit is contained in:
FelixChan
2025-10-21 15:27:03 +08:00
parent d6b68ef90b
commit b493ede479
15 changed files with 400 additions and 394 deletions

View File

@ -111,6 +111,23 @@ class MultiEmbedding(nn.Module):
def get_emb_by_key(self, key, token):
layer_idx = self.feature_list.index(key)
return self.layers[layer_idx](token)
def get_token_by_emb(self, key, token_emb):
'''
token_emb: B x emb_size
'''
layer_idx = self.feature_list.index(key)
embedding_layer = self.layers[layer_idx] # nn.Embedding
# compute cosine similarity between token_emb and embedding weights
emb_weights = embedding_layer.weight # vocab_size x emb_size
cos_sim = torch.nn.functional.cosine_similarity(
token_emb.unsqueeze(1), # B x 1 x emb_size
emb_weights.unsqueeze(0), # 1 x vocab_size x emb_size
dim=-1
) # B x vocab_size
# get the index of the most similar embedding
token_idx = torch.argmax(cos_sim, dim=-1) # B
return token_idx
class SummationEmbedder(MultiEmbedding):
def __init__(