from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32 for the variance computation
            x = x.float()
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            x = self.weight * x
        return x.to(dtype)


class FeedForward(nn.Module):
    """Gated feed-forward block (SwiGLU-style when hidden_act='silu')."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: Optional[int] = None,
                 dropout: float = 0.,
                 hidden_act: Literal['silu', 'gelu'] = 'silu'):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        # up_proj produces both the gate and the hidden branch in a single matmul.
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = get_activation(hidden_act)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        gate, hidden = self.up_proj(x).chunk(2, dim=-1)
        gate = self.act_fn(gate)
        return self.down_proj(self.dropout(gate * hidden))


class MultiHeadAttention(nn.Module):
    """Multi-head attention with QK-norm, dispatched through the HuggingFace attention interface."""

    def __init__(self,
                 hidden_size: int,
                 head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 bias: bool = False,
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention",
                                              "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"):
        super().__init__()
        self.num_attention_heads = num_attention_heads or hidden_size // head_dim
        self.head_dim = head_dim
        self.scaling = head_dim ** -0.5
        self.attention_dropout = attention_dropout
        self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, hidden_size, bias=bias)
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
        # To be compatible with the huggingface attention interface
        self.config = Qwen2Config()
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        key_states: Optional[Tensor] = None,
        value_states: Optional[Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        q = self.q_proj(hidden_states)
        query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)
        # Self-attention path: derive keys/values from hidden_states unless the caller
        # supplies precomputed (cross-attention) key/value states.
        if key_states is None or value_states is None:
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
            key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
            value_states = v.view(hidden_shape).transpose(1, 2)
        attn_output, attn_weights = self.attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
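

# A minimal shape sanity check for MultiHeadAttention (the helper name and tensor sizes
# below are illustrative assumptions, not part of the layer itself). It exercises both the
# self-attention path (keys/values derived internally) and the cross-attention path
# (precomputed key/value states), and only verifies output shapes.
def _check_attention_shapes(batch_size: int = 2, seq_len: int = 8, hidden_size: int = 256) -> None:
    attn = MultiHeadAttention(hidden_size=hidden_size, head_dim=64)
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Self-attention: no mask, keys/values computed from x.
    out, _ = attn(hidden_states=x, attention_mask=None)
    assert out.shape == (batch_size, seq_len, hidden_size)

    # Cross-attention: keys/values come from a separate memory sequence.
    memory = torch.randn(batch_size, seq_len, hidden_size)
    mem_shape = (batch_size, seq_len, -1, attn.head_dim)
    k = attn.k_norm(attn.k_proj(memory).view(mem_shape)).transpose(1, 2)
    v = attn.v_proj(memory).view(mem_shape).transpose(1, 2)
    out, _ = attn(hidden_states=x, attention_mask=None, key_states=k, value_states=v)
    assert out.shape == (batch_size, seq_len, hidden_size)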


class DecoderLayerWithCrossAttention(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 attn_head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 attn_bias: bool = False,
                 mlp_intermediate_size: Optional[int] = None,
                 mlp_dropout: float = 0.,
                 mlp_act: Literal['gelu', 'silu'] = 'silu',
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention",
                                              "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"):
        super().__init__()
        self.hidden_size = hidden_size
        self.self_attn = MultiHeadAttention(
            hidden_size, attn_head_dim, num_attention_heads, attention_dropout,
            attn_bias, rms_norm_eps, attn_implementation)
        self.cross_attn = MultiHeadAttention(
            hidden_size, attn_head_dim, num_attention_heads, attention_dropout,
            attn_bias, rms_norm_eps, attn_implementation)
        self.mlp = FeedForward(hidden_size, mlp_intermediate_size, mlp_dropout, mlp_act)
        self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(self, input_dict):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        hidden_states = input_dict['input_seq']
        cross_attn_states = input_dict['memory']
        cross_attn_mask = input_dict['memory_mask']
        # The memory mask is reused for self-attention; this assumes input_seq and memory
        # share the same sequence length (as in the tests below).
        attention_mask = input_dict['memory_mask']
        output_attentions = False

        # Self attention (pre-norm residual block)
        residual = hidden_states
        hidden_states = self.self_attn_layernorm(hidden_states)
        hidden_states, dec_attn_weight = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = residual + hidden_states

        # Cross attention: keys/values are projected from the encoder memory
        if cross_attn_states is not None:
            residual = hidden_states
            hidden_states = self.cross_attn_layernorm(hidden_states)
            cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
            key_states = self.cross_attn.k_proj(cross_attn_states)
            value_states = self.cross_attn.v_proj(cross_attn_states)
            key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
            value_states = value_states.view(cross_attn_shape).transpose(1, 2)
            hidden_states, cross_attn_weight = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=cross_attn_mask,
                key_states=key_states,
                value_states=value_states,
            )
            hidden_states = residual + hidden_states

        # Feed forward
        residual = hidden_states
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        output_dict = {'input_seq': hidden_states,
                       'memory': cross_attn_states,
                       'memory_mask': cross_attn_mask}
        if output_attentions:
            output_dict['dec_attn_weight'] = dec_attn_weight
            if cross_attn_states is not None:
                output_dict['cross_attn_weight'] = cross_attn_weight
        return output_dict
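

# Because the layer consumes and returns the same dict keys, a stack of these decoder
# layers can be chained directly with nn.Sequential. A minimal sketch; the helper name,
# depth, and sizes are illustrative assumptions, not values used by the tests below.
def _stack_decoder_layers(hidden_size: int = 256, num_layers: int = 2) -> nn.Sequential:
    return nn.Sequential(*[
        DecoderLayerWithCrossAttention(hidden_size=hidden_size, attn_head_dim=64)
        for _ in range(num_layers)
    ])
# Example usage (shapes are assumptions):
#   stack = _stack_decoder_layers()
#   out = stack({'input_seq': x, 'memory': mem, 'memory_mask': None})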


def sinusoidal_positional_embedding(x: Tensor, n: float = 10000.0) -> Tensor:
    device, dtype = x.device, x.dtype
    T = x.shape[-2]
    D = x.shape[-1]
    positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze(1)
    embeddings = torch.zeros(T, D, device=device, dtype=dtype)
    denominators = torch.pow(n, 2 * torch.arange(0, D // 2, device=device, dtype=dtype) / D)  # 10000^(2i/d_model), i is the index of the embedding pair
    embeddings[:, 0::2] = torch.sin(positions / denominators)  # sin(pos/10000^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions / denominators)  # cos(pos/10000^(2i/d_model))
    return x + embeddings


import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


# A simple synthetic dataset to test convergence
class SimpleSeq2SeqDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Simulated encoder output (memory)
        memory = torch.randn(self.seq_len, self.hidden_size)
        # Simulated decoder input tokens and targets
        decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
        target = torch.randint(0, self.vocab_size, (self.seq_len,))
        # Additive attention mask; zeros keep every position visible. The leading
        # singleton dim broadcasts over heads, so batching yields (batch, 1, seq, seq).
        memory_mask = torch.zeros(1, self.seq_len, self.seq_len)

        return {
            'memory': memory,
            'decoder_input': decoder_input,
            'target': target,
            'memory_mask': memory_mask,
        }


# A simple decoder model built from the layers above
class SimpleDecoderModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, hidden_size)

        # Positional encoding
        self.pos_encoding = sinusoidal_positional_embedding

        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayerWithCrossAttention(
                hidden_size=hidden_size,
                attn_head_dim=64,
                num_attention_heads=2,
                attention_dropout=0.1,
                attn_bias=False,
                mlp_intermediate_size=512,
                mlp_dropout=0.1,
                mlp_act='silu',
                rms_norm_eps=1e-6,
                attn_implementation="sdpa",
            )
            for _ in range(num_layers)
        ])

        # Output head
        self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, decoder_input, memory, memory_mask):
        # Embedding + positional encoding
        x = self.embedding(decoder_input)
        x = self.pos_encoding(x)

        # Pass the dict through the decoder layers
        input_dict = {
            'input_seq': x,
            'memory': memory,
            'memory_mask': memory_mask,
        }
        for layer in self.layers:
            input_dict = layer(input_dict)

        # Final projection to vocabulary logits
        hidden_states = self.output_norm(input_dict['input_seq'])
        logits = self.output_proj(hidden_states)
        return logits
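

# Note on the thresholds used in train_decoder_model below: because the dataset draws
# fresh random targets on every __getitem__ call, the expected cross-entropy cannot fall
# below roughly ln(vocab_size). A tiny helper to compute that floor (name is illustrative):
def _random_target_loss_floor(vocab_size: int = 100) -> float:
    import math
    return math.log(vocab_size)  # ~4.605 for vocab_size=100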


# Training function
def train_decoder_model():
    # Hyperparameters
    vocab_size = 100
    hidden_size = 256
    seq_len = 10
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset and dataloader
    dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
                                   vocab_size=vocab_size, hidden_size=hidden_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model
    model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
    model.to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    train_losses = []
    print("Starting training...")
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            # Move the batch to the device
            memory = batch['memory'].to(device)
            decoder_input = batch['decoder_input'].long().to(device)
            target = batch['target'].long().to(device)
            memory_mask = batch['memory_mask'].to(device)

            # Forward pass
            optimizer.zero_grad()
            logits = model(decoder_input, memory, memory_mask)

            # Loss: reshape logits to (batch_size * seq_len, vocab_size) and target to (batch_size * seq_len,)
            loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

        # Early-stop check: a clear drop in loss means the model is optimizing.
        # With random targets the loss floor is ~ln(vocab_size) ≈ 4.6, so this threshold
        # is rarely reached; it mainly confirms nothing has diverged.
        if epoch >= 2 and avg_loss < 4.0:
            print("✅ Model is converging.")
            return True

    # Check the final loss
    final_loss = train_losses[-1]
    if final_loss < 5.0:
        print("✅ Training finished; the model converged normally.")
        return True
    else:
        print("❌ Loss did not decrease noticeably; consider adjusting the model or hyperparameters.")
        return False


if __name__ == "__main__":
    # First, test a single decoder layer's forward pass
    print("Testing the DecoderLayerWithCrossAttention forward pass...")

    # Test data: batched (batch, seq_len, hidden_size) activations and a 4D additive mask
    batch_size, seq_len, hidden_size = 2, 64, 256
    decoder_input = torch.randn(batch_size, seq_len, hidden_size)
    memory = torch.randn(batch_size, seq_len, hidden_size)
    memory_mask = torch.zeros(batch_size, 1, seq_len, seq_len)

    # Single-layer test
    decoder_layer = DecoderLayerWithCrossAttention(
        hidden_size=hidden_size,
        attn_head_dim=64,
        num_attention_heads=4,
    )

    input_dict = {
        'input_seq': decoder_input,
        'memory': memory,
        'memory_mask': memory_mask,
    }

    try:
        output = decoder_layer(input_dict)
        print("✅ DecoderLayer forward pass test passed")
        print(f"Input shape: {decoder_input.shape}")
        print(f"Output shape: {output['input_seq'].shape}")

        # Run the full training test
        success = train_decoder_model()
        if success:
            print("\n🎉 Test complete! DecoderLayerWithCrossAttention trains and converges normally")
        else:
            print("\n⚠️ The training run may have problems; check the model structure or hyperparameters")
    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")
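

# An optional extra check (a minimal sketch; the helper name and sizes are assumptions,
# and it is not invoked automatically): run one forward/backward pass on a single layer
# and confirm that gradients reach both the self-attention and cross-attention projections.
def _check_cross_attention_gradients(hidden_size: int = 256, seq_len: int = 8) -> bool:
    layer = DecoderLayerWithCrossAttention(hidden_size=hidden_size, attn_head_dim=64)
    x = torch.randn(1, seq_len, hidden_size)
    mem = torch.randn(1, seq_len, hidden_size)
    out = layer({'input_seq': x, 'memory': mem, 'memory_mask': None})
    out['input_seq'].sum().backward()
    return (layer.self_attn.q_proj.weight.grad is not None
            and layer.cross_attn.k_proj.weight.grad is not None
            and layer.cross_attn.v_proj.weight.grad is not None)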