1127 update to latest

FelixChan
2025-11-27 15:44:17 +08:00
parent e16c84aab2
commit a34d39430e
153 changed files with 25705 additions and 53 deletions

Amadeus/toy_train.py

@@ -0,0 +1,454 @@
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x: Tensor) -> Tensor:
dtype = x.dtype
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
x = self.weight * x
return x.to(dtype)
class FeedForward(nn.Module):
def __init__(self,
hidden_size: int,
intermediate_size: Optional[int] = None,
dropout: float = 0.,
hidden_act: Literal['silu', 'gelu'] = 'silu'):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = get_activation(hidden_act)
self.dropout = nn.Dropout(dropout)
def forward(self, x: Tensor) -> Tensor:
gate, hidden = self.up_proj(x).chunk(2, dim=-1)
gate = self.act_fn(gate)
return self.down_proj(self.dropout(gate * hidden))
class MultiHeadAttention(nn.Module):
def __init__(self,
hidden_size: int,
head_dim: int = 64,
num_attention_heads: Optional[int] = None,
attention_dropout: float = 0.,
bias: bool = False,
rms_norm_eps: float = 1e-6,
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
):
super().__init__()
self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
self.head_dim = head_dim
self.scaling = head_dim ** -0.5
self.attention_dropout = attention_dropout
self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim,
bias=bias)
self.o_proj = nn.Linear(
self.num_attention_heads * self.head_dim, hidden_size,
bias=bias)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
# To be compatible with huggingface
self.config = Qwen2Config()
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
key_states: Optional[Tensor] = None,
value_states: Optional[Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
q = self.q_proj(hidden_states)
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)
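        # Self-attention: project k/v from hidden_states. For cross-attention the caller passes
        # key_states/value_states that it has already projected and normalized (see
        # DecoderLayerWithCrossAttention below), so k_proj/v_proj are skipped.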
        if key_states is None or value_states is None:
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
value_states = v.view(hidden_shape).transpose(1, 2)
attn_output, attn_weights = self.attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class DecoderLayerWithCrossAttention(nn.Module):
def __init__(self,
hidden_size: int,
attn_head_dim: int = 64,
num_attention_heads: Optional[int] = None,
attention_dropout: float = 0.,
attn_bias: bool = False,
mlp_intermediate_size: Optional[int] = None,
mlp_dropout: float = 0.,
mlp_act: Literal['gelu', 'silu'] = 'silu',
rms_norm_eps: float = 1e-6,
attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention", "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
):
super().__init__()
self.hidden_size = hidden_size
self.self_attn = MultiHeadAttention(
hidden_size,
attn_head_dim,
num_attention_heads,
attention_dropout,
attn_bias,
rms_norm_eps,
attn_implementation)
self.cross_attn = MultiHeadAttention(
hidden_size,
attn_head_dim,
num_attention_heads,
attention_dropout,
attn_bias,
rms_norm_eps,
attn_implementation)
self.mlp = FeedForward(
hidden_size,
mlp_intermediate_size,
mlp_dropout,
mlp_act)
self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_dict,
):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
            input_seq:   decoder hidden states, (batch, tgt_len, hidden_size)
            memory:      encoder hidden states, (batch, src_len, hidden_size); None skips cross-attention
            memory_mask: additive attention mask broadcastable to (batch, num_heads, tgt_len, src_len), or None
        '''
        hidden_states = input_dict['input_seq']
        cross_attn_states = input_dict['memory']
        cross_attn_mask = input_dict['memory_mask']
        # This toy layer reuses the memory mask for self-attention; no causal mask is applied here.
        attention_mask = input_dict['memory_mask']
        output_attentions = False
# Self Attention
residual = hidden_states
hidden_states = self.self_attn_layernorm(hidden_states)
hidden_states, dec_attn_weight = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
# Cross Attention
        if cross_attn_states is not None:
residual = hidden_states
hidden_states = self.cross_attn_layernorm(hidden_states)
cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
key_states = self.cross_attn.k_proj(cross_attn_states)
value_states = self.cross_attn.v_proj(cross_attn_states)
key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
value_states = value_states.view(cross_attn_shape).transpose(1, 2)
hidden_states, cross_attn_weight = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attn_mask,
key_states=key_states,
value_states=value_states
)
hidden_states = residual + hidden_states
# Feed Forward
residual = hidden_states
hidden_states = self.ffn_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
        if output_attentions:
            if cross_attn_states is not None:
# return (hidden_states, (dec_attn_weight, cross_attn_weight))
output_dict['dec_attn_weight'] = dec_attn_weight
output_dict['cross_attn_weight'] = cross_attn_weight
return output_dict
else:
# return (hidden_states, dec_attn_weight)
output_dict['dec_attn_weight'] = dec_attn_weight
return output_dict
else:
return output_dict
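# Illustrative usage of DecoderLayerWithCrossAttention (shapes are assumptions for this toy
# setup, not enforced by the layer):
#   layer = DecoderLayerWithCrossAttention(hidden_size=256, attn_head_dim=64)
#   out = layer({
#       'input_seq':   torch.randn(2, 16, 256),    # decoder states (batch, tgt_len, hidden)
#       'memory':      torch.randn(2, 16, 256),    # encoder states (batch, src_len, hidden)
#       'memory_mask': torch.zeros(2, 1, 16, 16),  # additive mask, 0 = attend
#   })
#   # out['input_seq'] has shape (2, 16, 256)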
def sinusoidal_positional_embedding(
x: Tensor,
n: float = 10000.0) -> Tensor:
device, dtype = x.device, x.dtype
T = x.shape[-2]
D = x.shape[-1]
positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
embeddings = torch.zeros(T, D, device=device, dtype=dtype)
denominators = torch.pow(n, 2*torch.arange(0, D//2, device=device, dtype=dtype)/D) # 10000^(2i/d_model), i is the index of embedding
embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model))
embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model))
return x + embeddings
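# Illustrative check (assumed shapes): the same (T, D) sinusoidal table is broadcast-added to
# every sequence in a batch, so the output shape matches the input:
#   x = torch.randn(2, 16, 256)
#   y = sinusoidal_positional_embedding(x)   # y.shape == (2, 16, 256)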
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# A simple synthetic dataset for testing convergence
class SimpleSeq2SeqDataset(Dataset):
def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
self.num_samples = num_samples
self.seq_len = seq_len
self.vocab_size = vocab_size
self.hidden_size = hidden_size
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
        # Simulated encoder output (memory)
        memory = torch.randn(self.seq_len, self.hidden_size)
        # Simulated decoder input and target token ids
        decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
        target = torch.randint(0, self.vocab_size, (self.seq_len,))
        # Additive attention mask (0 = attend everywhere). The leading singleton dim broadcasts
        # over heads, so batching yields (batch, 1, seq_len, seq_len), the shape the HF
        # attention interfaces expect; an all-ones (seq_len, seq_len) mask would not batch correctly.
        memory_mask = torch.zeros(1, self.seq_len, self.seq_len)
return {
'memory': memory,
'decoder_input': decoder_input,
'target': target,
'memory_mask': memory_mask
}
# A simple decoder model built from the layers above
class SimpleDecoderModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Positional encoding
        self.pos_encoding = sinusoidal_positional_embedding
        # Stack of decoder layers with cross-attention
self.layers = nn.ModuleList([
DecoderLayerWithCrossAttention(
hidden_size=hidden_size,
attn_head_dim=64,
num_attention_heads=2,
attention_dropout=0.1,
attn_bias=False,
mlp_intermediate_size=512,
mlp_dropout=0.1,
mlp_act='silu',
rms_norm_eps=1e-6,
attn_implementation="sdpa"
) for _ in range(num_layers)
])
        # Output head
        self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.output_proj = nn.Linear(hidden_size, vocab_size)
    def forward(self, decoder_input, memory, memory_mask):
        # Embedding + positional encoding
        x = self.embedding(decoder_input)
        x = self.pos_encoding(x)
        # Pass the sequence through the stack of decoder layers
input_dict = {
'input_seq': x,
'memory': memory,
'memory_mask': memory_mask
}
for layer in self.layers:
output_dict = layer(input_dict)
input_dict['input_seq'] = output_dict['input_seq']
        # Final output: norm + projection to vocabulary logits
hidden_states = self.output_norm(output_dict['input_seq'])
logits = self.output_proj(hidden_states)
return logits
# Training function
def train_decoder_model():
    # Hyperparameters
vocab_size = 100
hidden_size = 256
seq_len = 10
batch_size = 32
num_epochs = 10
learning_rate = 0.001
    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Dataset and dataloader
dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
vocab_size=vocab_size, hidden_size=hidden_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Initialize the model
model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
model.to(device)
    # Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Track training losses
    train_losses = []
    print("Starting training...")
model.train()
for epoch in range(num_epochs):
epoch_loss = 0.0
num_batches = 0
for batch in dataloader:
            # Move the batch to the device
            memory = batch['memory'].to(device)
            decoder_input = batch['decoder_input'].long().to(device)
            target = batch['target'].long().to(device)
            memory_mask = batch['memory_mask'].to(device)
            # Forward pass
            optimizer.zero_grad()
            logits = model(decoder_input, memory, memory_mask)
            # Loss: flatten logits to (batch_size * seq_len, vocab_size) and targets to (batch_size * seq_len)
            loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))
            # Backward pass and parameter update
            loss.backward()
optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_loss = epoch_loss / num_batches
train_losses.append(avg_loss)
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        # Early-stopping check: a clearly decreasing loss indicates the model is converging
        if epoch >= 2 and avg_loss < 4.0:  # reasonable threshold for this cross-entropy loss
            print("✅ The model is converging normally!")
return True
    # Check the final loss
    final_loss = train_losses[-1]
    if final_loss < 5.0:
        print("✅ Training finished; the model converged as expected!")
        return True
    else:
        print("❌ The loss did not decrease noticeably; the model or hyperparameters may need adjustment")
        return False
# Run the single-layer test and the full training
if __name__ == "__main__":
    # First, test a forward pass through a single decoder layer
    print("Testing DecoderLayerWithCrossAttention forward pass...")
    # Test data; the attention layers expect batched inputs of shape (batch, seq_len, hidden_size)
    batch_size, seq_len, hidden_size = 2, 64, 256
    decoder_input = torch.randn(batch_size, seq_len, hidden_size)
    memory = torch.randn(batch_size, seq_len, hidden_size)
    memory_mask = torch.zeros(batch_size, 1, seq_len, seq_len)  # additive mask, 0 = attend everywhere
    # Test a single layer
decoder_layer = DecoderLayerWithCrossAttention(
hidden_size=hidden_size,
attn_head_dim=64,
num_attention_heads=4
)
input_dict = {
'input_seq': decoder_input,
'memory': memory,
'memory_mask': memory_mask
}
    try:
        output = decoder_layer(input_dict)
        print("✅ DecoderLayer forward pass test passed")
        print(f"Input shape: {decoder_input.shape}")
        print(f"Output shape: {output['input_seq'].shape}")
        # Run the full training loop
success = train_decoder_model()
if success:
print("\n🎉 测试完成DecoderLayerWithCrossAttention能够正常训练和收敛")
else:
print("\n⚠️ 训练过程中可能存在问题,建议检查模型结构或超参数")
except Exception as e:
print(f"❌ 前向传播测试失败: {e}")