1013 update

FelixChan
2025-10-13 17:56:36 +08:00
parent d077e3210e
commit d6b68ef90b
17 changed files with 815 additions and 70 deletions


@@ -1358,6 +1358,7 @@ class Attention(Module):
         dim_latent_kv = None,
         latent_rope_subheads = None,
         onnxable = False,
+        use_gated_attention = False, # https://arxiv.org/abs/2505.06708
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
             enable_math = True,
@@ -1387,6 +1388,7 @@
         k_dim = dim_head * kv_heads
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads
+        gated_dim = out_dim

         # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
         # for eventually supporting multi-latent attention (MLA)
@@ -1447,7 +1449,8 @@

         self.to_v_gate = None
         if gate_values:
-            self.to_v_gate = nn.Linear(dim, out_dim)
+            # self.to_v_gate = nn.Linear(dim, out_dim)
+            self.to_v_gate = nn.Linear(dim_kv_input, gated_dim)
             self.to_v_gate_activation = F.silu if swiglu_values else F.sigmoid
             nn.init.constant_(self.to_v_gate.weight, 0)
             nn.init.constant_(self.to_v_gate.bias, 10)
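
Taken together, the hunks above wire an output gate into Attention: a linear projection of the attention input (now sized from dim_kv_input to gated_dim = out_dim) whose sigmoid output (SiLU when swiglu_values is set) is multiplied onto the attention values, with weights zero-initialized and the bias set to 10 so the gate starts effectively open. The sketch below is a minimal standalone illustration of that pattern, not the repository's actual class; the name GatedAttentionOutput, the arguments dim_in / dim_out / use_silu, and the assumed placement of the gate (on the merged-head output, before the final out projection) are illustrative assumptions.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GatedAttentionOutput(nn.Module):
        # Minimal sketch (hypothetical class, not from the repo) of the gating pattern in this diff:
        # project the attention input to a gate, squash it, and multiply it onto the attention output.
        def __init__(self, dim_in, dim_out, use_silu = False):
            super().__init__()
            self.to_gate = nn.Linear(dim_in, dim_out)
            self.gate_activation = F.silu if use_silu else torch.sigmoid

            # mirrors the init in the diff: zero weights plus a large positive bias give
            # sigmoid(10) ~ 1, so the gate starts fully open and the layer is a near no-op at init
            nn.init.constant_(self.to_gate.weight, 0)
            nn.init.constant_(self.to_gate.bias, 10)

        def forward(self, x, attn_out):
            # x:        (batch, seq, dim_in)  - the tensor the gate projection reads from
            # attn_out: (batch, seq, dim_out) - attention output being gated (assumed placement)
            gate = self.gate_activation(self.to_gate(x))
            return attn_out * gate

    # usage sketch
    gating = GatedAttentionOutput(dim_in = 512, dim_out = 512)
    x = torch.randn(2, 128, 512)
    attn_out = torch.randn(2, 128, 512)
    gated = gating(x, attn_out)  # same shape as attn_out; ~equal to attn_out right after init

Because the bias starts at 10, the sigmoid saturates near 1, so training begins from the ungated behaviour and the gate is only closed per position and channel as the weights move away from zero.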