67 lines
1.7 KiB
Python
67 lines
1.7 KiB
Python
from einops import rearrange
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Generator(nn.Module):
    """Pooled sequence predictor: feed-forward block, self-attention stack, linear head.

    Pipeline implemented by :meth:`forward`:

    1. A position-wise feed-forward block (``Linear -> ReLU -> Linear``) maps
       every frame from ``in_features`` back to ``in_features``.
    2. A stack of ``attn_layer_num`` self-attention layers is applied
       sequentially to the feed-forward output (no residual connections or
       normalisation between layers).
    3. The attention output is mean-pooled over time and concatenated with
       the max-pool (over time) of the feed-forward output, giving a
       ``2 * in_features`` vector per sequence.
    4. Dropout, a linear layer and ``tanh(x) * 2 + 3`` produce the final
       prediction, which is therefore bounded to the interval ``[1, 5]``.

    Args:
        in_features: Size ``D`` of the last dimension of the input features.
            Must be divisible by ``num_heads`` (enforced by
            ``nn.MultiheadAttention``).
        ffd_hidden_size: Hidden width of the feed-forward block.
        num_classes: Number of output values per sequence.
        attn_layer_num: Number of stacked self-attention layers.
        num_heads: Attention heads per layer. Defaults to 8, the value that
            was previously hard-coded.
        dropout: Dropout probability used both inside every attention layer
            and on the pooled representation. Defaults to 0.2, the value that
            was previously hard-coded in both places.
    """

    def __init__(
        self,
        in_features,
        ffd_hidden_size,
        num_classes,
        attn_layer_num,
        num_heads=8,
        dropout=0.2,
    ):
        super().__init__()

        # NOTE: the attribute names below (attn, ffd, dropout, fc, proj) are
        # part of the state_dict keys; renaming them would break loading of
        # existing checkpoints.
        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )

        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features),
        )

        self.dropout = nn.Dropout(dropout)

        # The head consumes [mean-pooled attention ; max-pooled feed-forward],
        # hence twice the feature width.
        self.fc = nn.Linear(in_features * 2, num_classes)

        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        """Predict one bounded vector per input sequence.

        Args:
            ssl_feature: Tensor of shape ``[B, T, D]`` with
                ``D == in_features``.
            judge_id: Accepted for interface compatibility but not used by
                this implementation.

        Returns:
            Tensor of shape ``[B, num_classes]`` with every value in
            ``[1, 5]`` (``tanh`` output scaled by 2 and shifted by 3).

        Raises:
            ValueError: If ``ssl_feature`` is not a 3-D tensor.
        """
        # Explicit rank check; the previous tuple-unpacking of ``.shape``
        # raised ValueError for the same inputs, only less readably.
        if ssl_feature.dim() != 3:
            raise ValueError(
                "ssl_feature must have shape [B, T, D], got "
                f"{tuple(ssl_feature.shape)}"
            )

        projected = self.ffd(ssl_feature)  # B, T, D

        # Self-attention: query, key and value are all the running features.
        attended = projected
        for attn in self.attn:
            attended, _ = attn(attended, attended, attended)

        # Mean over time of the attention output, max over time of the
        # feed-forward output (the latter bypasses the attention stack).
        pooled = torch.cat(
            [attended.mean(dim=1), projected.max(dim=1)[0]],
            dim=1,
        )  # B, 2D
        pooled = self.dropout(pooled)

        x = self.fc(pooled)  # B, num_classes

        # tanh is in [-1, 1]; scale and shift to [1, 5].
        return self.proj(x) * 2.0 + 3
|
|
|
|
|