1013 update
@@ -1358,6 +1358,7 @@ class Attention(Module):
         dim_latent_kv = None,
         latent_rope_subheads = None,
         onnxable = False,
+        use_gated_attention = False, # https://arxiv.org/abs/2505.06708
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
             enable_math = True,
@@ -1387,6 +1388,7 @@ class Attention(Module):
         k_dim = dim_head * kv_heads
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads
+        gated_dim = out_dim

         # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
         # for eventually supporting multi-latent attention (MLA)
@@ -1447,7 +1449,8 @@ class Attention(Module):

         self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, out_dim)
+            # self.to_v_gate = nn.Linear(dim, out_dim)
+            self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
             self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 10)
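The hunks above only construct the gate module. Below is a minimal standalone sketch of how such a value gate is typically applied following the gated-attention paper linked in the first hunk; it is an assumption for illustration, not the library's actual forward pass, and dim_kv_input, gated_dim, and the tensor shapes are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

# placeholder dimensions (assumed, not taken from the diff's surrounding code)
dim_kv_input = 512        # dimension of the input that feeds the k/v projections
gated_dim = 64 * 8        # stands in for out_dim = value_dim_head * heads
swiglu_values = False     # False selects the sigmoid branch, as in the diff

# gate construction mirroring the commit
to_v_gate = nn.Linear(dim_kv_input, gated_dim)
to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
nn.init.constant_(to_v_gate.weight, 0)
nn.init.constant_(to_v_gate.bias, 10)

# hypothetical application of the gate to an attention output
kv_input = torch.randn(2, 128, dim_kv_input)   # (batch, seq, dim_kv_input)
attn_out = torch.randn(2, 128, gated_dim)      # stand-in for merged per-head attention output
gates = to_v_gate_activation(to_v_gate(kv_input))
gated_out = attn_out * gates                   # elementwise value gating

Initializing the gate weight to 0 and the bias to 10 makes the sigmoid output start at roughly 1 for every position, so a gated layer begins training as (approximately) plain ungated attention and learns to close gates only where useful.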