from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: computed in float32 for stability, then cast back to the input dtype."""

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            x = x.float()
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            x = self.weight * x
        return x.to(dtype)
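
# Usage sketch (illustrative, not part of the original API): RMSNorm keeps the
# input shape; each vector is scaled to unit root-mean-square before the
# learned per-channel weight is applied.
def _demo_rmsnorm() -> None:
    norm = RMSNorm(hidden_size=256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    assert y.shape == x.shape
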
class FeedForward(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 intermediate_size: Optional[int] = None,
                 dropout: float = 0.,
                 hidden_act: Literal['silu', 'gelu'] = 'silu'):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        # up_proj fuses the gate and up projections; the output is split in forward().
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = get_activation(hidden_act)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        gate, hidden = self.up_proj(x).chunk(2, dim=-1)
        gate = self.act_fn(gate)

        return self.down_proj(self.dropout(gate * hidden))
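
# Usage sketch (illustrative, not part of the original API): the gated MLP maps
# (..., hidden_size) back to (..., hidden_size); with hidden_act='silu' this is
# the SwiGLU-style feed-forward block.
def _demo_feedforward() -> None:
    ffn = FeedForward(hidden_size=256, intermediate_size=512)
    x = torch.randn(2, 10, 256)
    assert ffn(x).shape == x.shape
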
class MultiHeadAttention(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 bias: bool = False,
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()

        self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
        self.head_dim = head_dim
        self.scaling = head_dim ** -0.5
        self.attention_dropout = attention_dropout

        self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, hidden_size, bias=bias)

        # Per-head query/key normalization.
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

        self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

        # Attributes expected by the Hugging Face attention interfaces.
        self.config = Qwen2Config()
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        key_states: Optional[Tensor] = None,
        value_states: Optional[Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]

        q = self.q_proj(hidden_states)

        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)

        # If no key/value states are supplied, fall back to self-attention.
        if key_states is None or value_states is None:
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

            key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
            value_states = v.view(hidden_shape).transpose(1, 2)

        attn_output, attn_weights = self.attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
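
# Usage sketch (illustrative, not part of the original API): called without
# key/value states the module performs self-attention over hidden_states of
# shape (batch, seq_len, hidden_size); externally projected key/value states
# (as the decoder layer below supplies) turn it into cross-attention.
def _demo_attention() -> None:
    attn = MultiHeadAttention(hidden_size=256, head_dim=64)
    x = torch.randn(2, 10, 256)
    out, _ = attn(hidden_states=x, attention_mask=None)
    assert out.shape == x.shape
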
class DecoderLayerWithCrossAttention(nn.Module):
    def __init__(self,
                 hidden_size: int,

                 attn_head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 attn_bias: bool = False,

                 mlp_intermediate_size: Optional[int] = None,
                 mlp_dropout: float = 0.,
                 mlp_act: Literal['gelu', 'silu'] = 'silu',

                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()
        self.hidden_size = hidden_size

        self.self_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)

        self.cross_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)

        self.mlp = FeedForward(
            hidden_size,
            mlp_intermediate_size,
            mlp_dropout,
            mlp_act)

        self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, input_dict: dict) -> dict:
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        hidden_states = input_dict['input_seq']
        cross_attn_states = input_dict['memory']
        cross_attn_mask = input_dict['memory_mask']
        # Note: the same memory mask is reused as the self-attention mask.
        attention_mask = input_dict['memory_mask']
        output_attentions = False  # attention weights are not returned by default

        # Self Attention (pre-norm residual block)
        residual = hidden_states
        hidden_states = self.self_attn_layernorm(hidden_states)
        hidden_states, dec_attn_weight = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = residual + hidden_states

        # Cross Attention: keys/values are projected from the encoder memory.
        if cross_attn_states is not None:
            residual = hidden_states
            hidden_states = self.cross_attn_layernorm(hidden_states)
            cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
            key_states = self.cross_attn.k_proj(cross_attn_states)
            value_states = self.cross_attn.v_proj(cross_attn_states)

            key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
            value_states = value_states.view(cross_attn_shape).transpose(1, 2)
            hidden_states, cross_attn_weight = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=cross_attn_mask,
                key_states=key_states,
                value_states=value_states
            )
            hidden_states = residual + hidden_states

        # Feed Forward
        residual = hidden_states
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
        if output_attentions:
            output_dict['dec_attn_weight'] = dec_attn_weight
            if cross_attn_states is not None:
                output_dict['cross_attn_weight'] = cross_attn_weight
        return output_dict
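
# Usage sketch (illustrative, not part of the original API): the layer consumes
# and returns a dict so layers can be chained; tensors are assumed to be
# batched as (batch, seq_len, hidden_size), and 'memory' may have a different
# sequence length than 'input_seq'.
def _demo_decoder_layer() -> None:
    layer = DecoderLayerWithCrossAttention(hidden_size=256, attn_head_dim=64)
    x = torch.randn(2, 10, 256)        # decoder stream
    memory = torch.randn(2, 12, 256)   # encoder output
    out = layer({'input_seq': x, 'memory': memory, 'memory_mask': None})
    assert out['input_seq'].shape == x.shape
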
def sinusoidal_positional_embedding(
        x: Tensor,
        n: float = 10000.0) -> Tensor:
    device, dtype = x.device, x.dtype
    T = x.shape[-2]
    D = x.shape[-1]

    positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
    embeddings = torch.zeros(T, D, device=device, dtype=dtype)

    denominators = torch.pow(n, 2 * torch.arange(0, D // 2, device=device, dtype=dtype) / D)  # 10000^(2i/d_model), i is the embedding index
    embeddings[:, 0::2] = torch.sin(positions / denominators)  # sin(pos/10000^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions / denominators)  # cos(pos/10000^(2i/d_model))

    return x + embeddings
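
# Usage sketch (illustrative, not part of the original API): the (T, D) sin/cos
# table is broadcast-added over any leading batch dimensions, so identical
# positions receive identical offsets.
def _demo_positional_embedding() -> None:
    x = torch.zeros(2, 10, 256)
    y = sinusoidal_positional_embedding(x)
    assert y.shape == x.shape
    assert torch.allclose(y[0], y[1])
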
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


# A simple synthetic dataset to check that the decoder can converge
class SimpleSeq2SeqDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Simulated encoder output (memory)
        memory = torch.randn(self.seq_len, self.hidden_size)

        # Simulated decoder input and target token ids
        decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
        target = torch.randint(0, self.vocab_size, (self.seq_len,))

        # Additive attention mask (Hugging Face convention: 0.0 = attend);
        # shape (1, seq_len, seq_len) batches to (batch, 1, seq_len, seq_len).
        memory_mask = torch.zeros(1, self.seq_len, self.seq_len)

        return {
            'memory': memory,
            'decoder_input': decoder_input,
            'target': target,
            'memory_mask': memory_mask
        }
# A simple decoder model built from the layers above
class SimpleDecoderModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # Positional encoding
        self.pos_encoding = sinusoidal_positional_embedding

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayerWithCrossAttention(
                hidden_size=hidden_size,
                attn_head_dim=64,
                num_attention_heads=2,
                attention_dropout=0.1,
                attn_bias=False,
                mlp_intermediate_size=512,
                mlp_dropout=0.1,
                mlp_act='silu',
                rms_norm_eps=1e-6,
                attn_implementation="sdpa"
            ) for _ in range(num_layers)
        ])

        # Output head
        self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, decoder_input, memory, memory_mask):
        # Embedding + positional encoding
        x = self.embedding(decoder_input)
        x = self.pos_encoding(x)

        # Run through the stack of decoder layers
        input_dict = {
            'input_seq': x,
            'memory': memory,
            'memory_mask': memory_mask
        }

        for layer in self.layers:
            output_dict = layer(input_dict)
            input_dict['input_seq'] = output_dict['input_seq']

        # Final projection to vocabulary logits
        hidden_states = self.output_norm(output_dict['input_seq'])
        logits = self.output_proj(hidden_states)

        return logits
# Training loop
def train_decoder_model():
    # Hyperparameters
    vocab_size = 100
    hidden_size = 256
    seq_len = 10
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset and dataloader
    dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
                                   vocab_size=vocab_size, hidden_size=hidden_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize the model
    model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
    model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training history
    train_losses = []

    print("Starting training...")
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            # Move the batch to the device
            memory = batch['memory'].to(device)
            decoder_input = batch['decoder_input'].long().to(device)
            target = batch['target'].long().to(device)
            memory_mask = batch['memory_mask'].to(device)

            # Forward pass
            optimizer.zero_grad()
            logits = model(decoder_input, memory, memory_mask)

            # Compute the loss: reshape logits to (batch_size * seq_len, vocab_size)
            # and target to (batch_size * seq_len)
            loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Early-stopping check: a clearly decreasing loss means the model is converging
        if epoch >= 2 and avg_loss < 4.0:  # reasonable threshold for this cross-entropy loss
            print("✅ Model is converging normally!")
            return True

    # Check the final loss
    final_loss = train_losses[-1]
    if final_loss < 5.0:
        print("✅ Training finished; the model converged normally!")
        return True
    else:
        print("❌ Loss did not decrease noticeably; the model or hyperparameters may need tuning")
        return False
# Run the tests
if __name__ == "__main__":
    # First test a single decoder layer's forward pass
    print("Testing the DecoderLayerWithCrossAttention forward pass...")

    # Test data (batched: the layer's view/transpose logic assumes a batch dimension)
    batch_size, seq_len, hidden_size = 2, 64, 256
    decoder_input = torch.randn(batch_size, seq_len, hidden_size)
    memory = torch.randn(batch_size, seq_len, hidden_size)
    # Additive attention mask (0.0 = attend everywhere)
    memory_mask = torch.zeros(batch_size, 1, seq_len, seq_len)

    # Test a single layer
    decoder_layer = DecoderLayerWithCrossAttention(
        hidden_size=hidden_size,
        attn_head_dim=64,
        num_attention_heads=4
    )

    input_dict = {
        'input_seq': decoder_input,
        'memory': memory,
        'memory_mask': memory_mask
    }

    try:
        output = decoder_layer(input_dict)
        print("✅ DecoderLayer forward pass test passed")
        print(f"Input shape: {decoder_input.shape}")
        print(f"Output shape: {output['input_seq'].shape}")

        # Run the full training test
        success = train_decoder_model()
        if success:
            print("\n🎉 Test complete! DecoderLayerWithCrossAttention trains and converges normally")
        else:
            print("\n⚠️ The training run may have problems; check the model structure or hyperparameters")

    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")