# MIDIFoundationModel/Amadeus/toy_train.py
from math import ceil
from typing import Optional, Union, Literal
from typing_extensions import Unpack
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import Qwen2Config
from transformers.activations import get_activation
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class RMSNorm(nn.Module):
    """Root-mean-square layer norm, computed in float32 for numerical stability."""

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: Tensor) -> Tensor:
        dtype = x.dtype
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            x = x.float()
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
            x = self.weight * x
        return x.to(dtype)
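

# Optional sanity-check sketch (not part of the training flow): with default all-ones
# weights, the custom RMSNorm above should match torch's built-in nn.RMSNorm
# (PyTorch >= 2.4) up to float precision. The function name `_check_rmsnorm` is illustrative.
def _check_rmsnorm(hidden_size: int = 256, eps: float = 1e-6) -> bool:
    x = torch.randn(2, 8, hidden_size)
    custom = RMSNorm(hidden_size, eps=eps)
    builtin = nn.RMSNorm(hidden_size, eps=eps)
    # Both compute x * rsqrt(mean(x^2, dim=-1) + eps), scaled by a learnable weight.
    return torch.allclose(custom(x), builtin(x), atol=1e-5)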


class FeedForward(nn.Module):
    """Gated MLP: up_proj produces gate and hidden halves, combined as act(gate) * hidden."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: Optional[int] = None,
                 dropout: float = 0.,
                 hidden_act: Literal['silu', 'gelu'] = 'silu'):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = get_activation(hidden_act)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        gate, hidden = self.up_proj(x).chunk(2, dim=-1)
        gate = self.act_fn(gate)
        return self.down_proj(self.dropout(gate * hidden))
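

# Optional shape sketch (illustrative, never called): the gated MLP packs the gate and
# hidden branches into a single up_proj of width 2 * intermediate_size, then splits
# them with chunk(2, dim=-1). The block is shape-preserving on the last dimension.
def _check_feedforward_shapes(hidden_size: int = 256) -> None:
    ffn = FeedForward(hidden_size)  # intermediate_size defaults to int(hidden_size * 8 / 3)
    x = torch.randn(2, 8, hidden_size)
    assert ffn.up_proj.out_features == 2 * ffn.intermediate_size
    assert ffn(x).shape == x.shape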


class MultiHeadAttention(nn.Module):
    """Multi-head attention with QK-norm, dispatching to a HuggingFace attention backend."""

    def __init__(self,
                 hidden_size: int,
                 head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 bias: bool = False,
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention",
                                              "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()
        self.num_attention_heads = num_attention_heads or int(hidden_size // head_dim)
        self.head_dim = head_dim
        self.scaling = head_dim ** -0.5
        self.attention_dropout = attention_dropout
        self.q_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(hidden_size, self.num_attention_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, hidden_size, bias=bias)
        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
        self.attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]
        # To be compatible with the HuggingFace attention interfaces
        self.config = Qwen2Config()
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        key_states: Optional[Tensor] = None,
        value_states: Optional[Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        q = self.q_proj(hidden_states)
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_norm(q.view(hidden_shape)).transpose(1, 2)
        # If no external key/value states are supplied, compute self-attention K/V from hidden_states.
        if key_states is None or value_states is None:
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
            key_states = self.k_norm(k.view(hidden_shape)).transpose(1, 2)
            value_states = v.view(hidden_shape).transpose(1, 2)
        attn_output, attn_weights = self.attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
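

# Optional usage sketch (illustrative): a self-attention call with the "sdpa" backend.
# The mask layout assumed here is the usual HuggingFace 4D additive mask of shape
# (batch, 1, q_len, kv_len), where 0 means "attend".
def _check_self_attention(batch: int = 2, seq_len: int = 16, hidden_size: int = 256) -> None:
    attn = MultiHeadAttention(hidden_size, head_dim=64)
    x = torch.randn(batch, seq_len, hidden_size)
    mask = torch.zeros(batch, 1, seq_len, seq_len)  # all-zero additive mask: nothing is masked out
    out, _ = attn(hidden_states=x, attention_mask=mask)
    assert out.shape == (batch, seq_len, hidden_size)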


class DecoderLayerWithCrossAttention(nn.Module):
    """Pre-norm Transformer decoder layer: self-attention, optional cross-attention, and a gated MLP."""

    def __init__(self,
                 hidden_size: int,
                 attn_head_dim: int = 64,
                 num_attention_heads: Optional[int] = None,
                 attention_dropout: float = 0.,
                 attn_bias: bool = False,
                 mlp_intermediate_size: Optional[int] = None,
                 mlp_dropout: float = 0.,
                 mlp_act: Literal['gelu', 'silu'] = 'silu',
                 rms_norm_eps: float = 1e-6,
                 attn_implementation: Literal["flash_attention_3", "flash_attention_2", "flex_attention",
                                              "paged_attention", "sdpa", "sdpa_paged", "eager_paged"] = "sdpa"
                 ):
        super().__init__()
        self.hidden_size = hidden_size
        self.self_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)
        self.cross_attn = MultiHeadAttention(
            hidden_size,
            attn_head_dim,
            num_attention_heads,
            attention_dropout,
            attn_bias,
            rms_norm_eps,
            attn_implementation)
        self.mlp = FeedForward(
            hidden_size,
            mlp_intermediate_size,
            mlp_dropout,
            mlp_act)
        self.self_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.cross_attn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)
        self.ffn_layernorm = nn.RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_dict,
    ):
        '''
        input_dict = {'input_seq': input_seq, 'memory': memory, 'memory_mask': CA_attn_mask}
        '''
        hidden_states = input_dict['input_seq']
        cross_attn_states = input_dict['memory']
        cross_attn_mask = input_dict['memory_mask']
        # Note: the self-attention block currently reuses the memory mask.
        attention_mask = input_dict['memory_mask']
        output_attentions = False
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn_layernorm(hidden_states)
        hidden_states, dec_attn_weight = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = residual + hidden_states
        # Cross Attention (skipped when no memory is given)
        if cross_attn_states is not None:
            residual = hidden_states
            hidden_states = self.cross_attn_layernorm(hidden_states)
            # Project the encoder memory into key/value states for cross-attention.
            cross_attn_shape = (*cross_attn_states.shape[:-1], -1, self.cross_attn.head_dim)
            key_states = self.cross_attn.k_proj(cross_attn_states)
            value_states = self.cross_attn.v_proj(cross_attn_states)
            key_states = self.cross_attn.k_norm(key_states.view(cross_attn_shape)).transpose(1, 2)
            value_states = value_states.view(cross_attn_shape).transpose(1, 2)
            hidden_states, cross_attn_weight = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=cross_attn_mask,
                key_states=key_states,
                value_states=value_states
            )
            hidden_states = residual + hidden_states
        # Feed Forward
        residual = hidden_states
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        output_dict = {'input_seq': hidden_states, 'memory': cross_attn_states, 'memory_mask': cross_attn_mask}
        if output_attentions:
            output_dict['dec_attn_weight'] = dec_attn_weight
            if cross_attn_states is not None:
                output_dict['cross_attn_weight'] = cross_attn_weight
        return output_dict
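

# Optional usage sketch (illustrative): when 'memory' is None the cross-attention block
# is skipped and the layer degenerates to self-attention + MLP. The mask layout assumed
# here is the HuggingFace-style 4D additive mask (batch, 1, q_len, kv_len), 0 = attend.
def _check_decoder_layer_without_memory(batch: int = 2, seq_len: int = 16, hidden_size: int = 256) -> None:
    layer = DecoderLayerWithCrossAttention(hidden_size=hidden_size, attn_head_dim=64)
    input_dict = {
        'input_seq': torch.randn(batch, seq_len, hidden_size),
        'memory': None,  # skip cross-attention
        'memory_mask': torch.zeros(batch, 1, seq_len, seq_len),  # still used by self-attention
    }
    out = layer(input_dict)
    assert out['input_seq'].shape == (batch, seq_len, hidden_size)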


def sinusoidal_positional_embedding(
        x: Tensor,
        n: float = 10000.0) -> Tensor:
    device, dtype = x.device, x.dtype
    T = x.shape[-2]
    D = x.shape[-1]
    positions = torch.arange(0, T, device=device, dtype=dtype).unsqueeze_(1)
    embeddings = torch.zeros(T, D, device=device, dtype=dtype)
    denominators = torch.pow(n, 2 * torch.arange(0, D // 2, device=device, dtype=dtype) / D)  # n^(2i/d_model), i is the index of the embedding pair
    embeddings[:, 0::2] = torch.sin(positions / denominators)  # sin(pos / n^(2i/d_model))
    embeddings[:, 1::2] = torch.cos(positions / denominators)  # cos(pos / n^(2i/d_model))
    return x + embeddings
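

# Optional sanity-check sketch (illustrative): at position 0 the sinusoidal table is
# sin(0) = 0 on even channels and cos(0) = 1 on odd channels. Passing a zero input
# returns the raw table. Assumes an even embedding dimension D.
def _check_sinusoidal_pe(T: int = 8, D: int = 16) -> None:
    x = torch.zeros(T, D)
    pe = sinusoidal_positional_embedding(x)  # x is zero, so the result is the raw table
    assert torch.allclose(pe[0, 0::2], torch.zeros(D // 2))  # sin(0) = 0
    assert torch.allclose(pe[0, 1::2], torch.ones(D // 2))   # cos(0) = 1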


import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


# A simple synthetic dataset to test convergence
class SimpleSeq2SeqDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=10, vocab_size=100, hidden_size=256):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Simulated encoder output (memory)
        memory = torch.randn(self.seq_len, self.hidden_size)
        # Simulated decoder input and target
        decoder_input = torch.randint(0, self.vocab_size, (self.seq_len,))
        target = torch.randint(0, self.vocab_size, (self.seq_len,))
        # Simple all-zero additive attention mask (0 = attend). Shape (1, T, T) per sample
        # batches into the (B, 1, T, T) layout expected by the attention layers above.
        memory_mask = torch.zeros(1, self.seq_len, self.seq_len)
        return {
            'memory': memory,
            'decoder_input': decoder_input,
            'target': target,
            'memory_mask': memory_mask
        }
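

# Optional shape sketch (illustrative): what one batch from the synthetic dataset looks
# like after collation by DataLoader. The per-sample (1, T, T) mask stacks into the
# (B, 1, T, T) additive-mask layout consumed by the attention layers above.
def _check_batch_shapes(batch_size: int = 4, seq_len: int = 10, hidden_size: int = 256) -> None:
    loader = DataLoader(SimpleSeq2SeqDataset(num_samples=16, seq_len=seq_len,
                                             hidden_size=hidden_size),
                        batch_size=batch_size)
    batch = next(iter(loader))
    assert batch['memory'].shape == (batch_size, seq_len, hidden_size)
    assert batch['decoder_input'].shape == (batch_size, seq_len)
    assert batch['memory_mask'].shape == (batch_size, 1, seq_len, seq_len)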


# A simple decoder model built from the layers above
class SimpleDecoderModel(nn.Module):
    def __init__(self, vocab_size, hidden_size=256, num_layers=3):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Positional encoding
        self.pos_encoding = sinusoidal_positional_embedding
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayerWithCrossAttention(
                hidden_size=hidden_size,
                attn_head_dim=64,
                num_attention_heads=2,
                attention_dropout=0.1,
                attn_bias=False,
                mlp_intermediate_size=512,
                mlp_dropout=0.1,
                mlp_act='silu',
                rms_norm_eps=1e-6,
                attn_implementation="sdpa"
            ) for _ in range(num_layers)
        ])
        # Output head
        self.output_norm = nn.RMSNorm(hidden_size, eps=1e-6)
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, decoder_input, memory, memory_mask):
        # Embedding + positional encoding
        x = self.embedding(decoder_input)
        x = self.pos_encoding(x)
        # Run through the stacked decoder layers
        input_dict = {
            'input_seq': x,
            'memory': memory,
            'memory_mask': memory_mask
        }
        for layer in self.layers:
            output_dict = layer(input_dict)
            input_dict['input_seq'] = output_dict['input_seq']
        # Final norm and projection to vocabulary logits
        hidden_states = self.output_norm(output_dict['input_seq'])
        logits = self.output_proj(hidden_states)
        return logits
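

# Optional shape sketch (illustrative): a forward pass of the toy model produces
# per-position vocabulary logits of shape (batch, seq_len, vocab_size).
def _check_model_logits(batch: int = 2, seq_len: int = 10, vocab_size: int = 100, hidden_size: int = 256) -> None:
    model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
    decoder_input = torch.randint(0, vocab_size, (batch, seq_len))
    memory = torch.randn(batch, seq_len, hidden_size)
    memory_mask = torch.zeros(batch, 1, seq_len, seq_len)  # additive mask: 0 = attend everywhere
    logits = model(decoder_input, memory, memory_mask)
    assert logits.shape == (batch, seq_len, vocab_size)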


# Training routine
def train_decoder_model():
    # Hyperparameters
    vocab_size = 100
    hidden_size = 256
    seq_len = 10
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Dataset and dataloader
    dataset = SimpleSeq2SeqDataset(num_samples=1000, seq_len=seq_len,
                                   vocab_size=vocab_size, hidden_size=hidden_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Initialize the model
    model = SimpleDecoderModel(vocab_size=vocab_size, hidden_size=hidden_size)
    model.to(device)
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Training record
    train_losses = []
    print("Starting training...")
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            # Move tensors to the device
            memory = batch['memory'].to(device)
            decoder_input = batch['decoder_input'].long().to(device)
            target = batch['target'].long().to(device)
            memory_mask = batch['memory_mask'].to(device)
            # Forward pass
            optimizer.zero_grad()
            logits = model(decoder_input, memory, memory_mask)
            # Loss: reshape logits to (batch_size * seq_len, vocab_size) and targets to (batch_size * seq_len)
            loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))
            # Backward pass
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
        # Early-stopping check: a clear drop in loss indicates the model is converging
        if epoch >= 2 and avg_loss < 4.0:  # a reasonable threshold for the cross-entropy loss
            print("✅ Model is converging normally!")
            return True
    # Check the final loss
    final_loss = train_losses[-1]
    if final_loss < 5.0:
        print("✅ Training finished; the model converged normally!")
        return True
    else:
        print("❌ Loss did not drop noticeably; the model or hyperparameters may need adjusting")
        return False


# Run the training
if __name__ == "__main__":
    # First, test the forward pass of a single decoder layer
    print("Testing the DecoderLayerWithCrossAttention forward pass...")
    # Build batched test data; the attention mask follows the 4D additive layout (0 = attend)
    # used throughout this script.
    batch_size, seq_len, hidden_size = 2, 64, 256
    decoder_input = torch.randn(batch_size, seq_len, hidden_size)
    memory = torch.randn(batch_size, seq_len, hidden_size)
    memory_mask = torch.zeros(batch_size, 1, seq_len, seq_len)
    # Test a single layer
    decoder_layer = DecoderLayerWithCrossAttention(
        hidden_size=hidden_size,
        attn_head_dim=64,
        num_attention_heads=4
    )
    input_dict = {
        'input_seq': decoder_input,
        'memory': memory,
        'memory_mask': memory_mask
    }
    try:
        output = decoder_layer(input_dict)
        print("✅ DecoderLayer forward pass test passed")
        print(f"Input shape: {decoder_input.shape}")
        print(f"Output shape: {output['input_seq'].shape}")
        # Run the full training loop
        success = train_decoder_model()
        if success:
            print("\n🎉 Test complete: DecoderLayerWithCrossAttention trains and converges normally")
        else:
            print("\n⚠️ There may be a problem during training; check the model structure or hyperparameters")
    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")