import torch
from torch import nn


class MultiHeadCoAttention(nn.Module):
    """Multi-head attention that computes one set of attention weights from the query/key
    pair and applies it to two value streams in parallel: a per-channel ('multi') stream
    and a channel-averaged ('single') stream."""

    def __init__(self, multi_dim, single_dim, num_heads):
        assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
        assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
        super().__init__()
        # Query/key projections operate on the single stream; each value stream
        # gets its own projection.
        self.q_proj = nn.Linear(single_dim, single_dim)
        self.k_proj = nn.Linear(single_dim, single_dim)
        self.multi_v_proj = nn.Linear(multi_dim, multi_dim)
        self.single_v_proj = nn.Linear(single_dim, single_dim)

        self.multi_out_proj = nn.Linear(multi_dim, multi_dim)
        self.single_out_proj = nn.Linear(single_dim, single_dim)

        self.multi_dim = multi_dim
        self.single_dim = single_dim
        self.num_heads = num_heads

    def forward(self, query, key, multi_value, single_value):
        # Inputs arrive time-major; bring the batch dimension to the front so the
        # attention matmuls batch over (batch, heads).
        query = torch.transpose(query, 0, 1)
        key = torch.transpose(key, 0, 1)
        multi_value = torch.permute(multi_value, (1, 2, 0, 3))
        single_value = torch.permute(single_value, (1, 2, 0, 3))

        # Project, then split the feature dimension into heads, stacking heads on dim 1.
        q = torch.split(self.q_proj(query), self.single_dim // self.num_heads, dim=-1)
        q = torch.stack(q, dim=1)

        k = torch.split(self.k_proj(key), self.single_dim // self.num_heads, dim=-1)
        k = torch.stack(k, dim=1)

        multi_v = torch.split(self.multi_v_proj(multi_value), self.multi_dim // self.num_heads,
                              dim=-1)
        multi_v = torch.stack(multi_v, dim=1)

        single_v = torch.split(self.single_v_proj(single_value), self.single_dim // self.num_heads,
                               dim=-1)
        single_v = torch.stack(single_v, dim=1)

        # Fold the trailing token axis of q and k into the feature axis (a singleton in the
        # CoAttention use below) and scale by sqrt(d_head).
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        normalizer = torch.sqrt(torch.Tensor([float(q.shape[-1])]).to(q.device))

        sim_mat = torch.matmul(q, torch.transpose(k, -2, -1)) / normalizer
        # One attention matrix per head, broadcast over the channel axis of the value streams.
        att_mat = torch.unsqueeze(nn.functional.softmax(sim_mat, dim=-1), 2)

        # The same attention weights aggregate both value streams.
        multi_result = torch.matmul(att_mat, multi_v)
        single_result = torch.matmul(att_mat, single_v)

        # Restore time-major layout and merge the heads back into the feature dimension.
        multi_result = torch.permute(multi_result, (3, 0, 2, 1, 4))
        single_result = torch.permute(single_result, (3, 0, 2, 1, 4))
        multi_result = torch.reshape(multi_result, multi_result.shape[:-2] + (-1,))
        single_result = torch.reshape(single_result, single_result.shape[:-2] + (-1,))

        multi_result = self.multi_out_proj(multi_result)
        single_result = self.single_out_proj(single_result)
        return multi_result, single_result


class CoAttention(nn.Module):
    """Channel co-attention block: a channel-averaged ('single') stream and a per-channel
    ('multi') stream attend over frames with shared attention weights, the single stream
    then runs through cross-frame self-attention, and both streams are fused back to
    embed_dim."""

    def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
                 init_mult=1e-2):
        super().__init__()
        self.init_mult = init_mult

        # Input projections from the embedding dimension into the two streams.
        self.in_single_proj = nn.Linear(embed_dim, single_dim)
        self.in_single_ln = nn.LayerNorm(single_dim)

        self.in_multi_proj = nn.Linear(embed_dim, multi_dim)
        self.in_multi_ln = nn.LayerNorm(multi_dim)

        # Co-attention over frames with weights shared between the two streams.
        self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
        self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
        self.mca_single_out_ln = nn.LayerNorm(single_dim)

        # Standard self-attention across frames for the single stream.
        self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout,
                                                     bias=True, kdim=None, vdim=None)
        self.mha_ln = nn.LayerNorm(single_dim)

        # Fuses the tiled single stream with the multi stream back to embed_dim.
        self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)

        self.miso = False

    def scale_weights(self):
        # Zero the fusion bias and scale its weights down by init_mult.
        self.cat_proj.bias.data *= 0.
        self.cat_proj.weight.data *= self.init_mult

    def forward(self, x):
        # x: (frames, batch, channels, embed_dim)
        frames, B, chans, feat_dim = x.shape

        # Single stream: average over channels, then project and normalize.
        single_x = torch.mean(x, dim=2)
        single_x = self.in_single_ln(self.in_single_proj(single_x)).unsqueeze(dim=-2)

        # Multi stream: per-channel projection and normalization.
        multi_x = self.in_multi_ln(self.in_multi_proj(x))

        # Co-attention over frames, with residual connections on both streams.
        multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x)
        single_x = single_x + single_mca
        multi_x = multi_x + multi_mca
        multi_x = self.mca_multi_out_ln(multi_x)
        single_x = torch.squeeze(self.mca_single_out_ln(single_x), -2)

        # Cross-frame self-attention on the single stream, with a residual connection.
        single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False)
        single_x = self.mha_ln(single_mha + single_x)

        # Tile the single stream across channels and fuse it with the multi stream.
        single_x = single_x.unsqueeze(-2)
        single_x_tile = torch.tile(single_x, (1, 1, chans, 1))
        cat_x = torch.cat([single_x_tile, multi_x], dim=-1)
        out = self.cat_proj(cat_x)

        return out
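

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: a sketch of how this block
    # appears to be called, assuming an input laid out as (frames, batch, channels,
    # embed_dim). The sizes below are illustrative placeholders, not values taken from
    # any particular configuration.
    frames, batch, chans, embed_dim = 4, 2, 3, 768
    x = torch.randn(frames, batch, chans, embed_dim)

    block = CoAttention(embed_dim=embed_dim, single_dim=256, multi_dim=64, n_heads=8)
    block.scale_weights()  # optional: shrinks the fusion projection (see scale_weights above)

    out = block(x)
    print(out.shape)  # expected: torch.Size([4, 2, 3, 768]), same layout as the input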