# DiCoW_v3_2 / coattention.py
import math

import torch
from torch import nn

class MultiHeadCoAttention(nn.Module):
    """Multi-head co-attention: a single attention matrix, computed from the
    channel-averaged (single) stream, is shared to aggregate both the
    per-channel (multi) values and the single values."""

    def __init__(self, multi_dim, single_dim, num_heads):
        super().__init__()
        assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
        assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
        self.q_proj = nn.Linear(single_dim, single_dim)
        self.k_proj = nn.Linear(single_dim, single_dim)
        self.multi_v_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_v_proj = nn.Linear(single_dim, single_dim)  # D
        self.multi_out_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_out_proj = nn.Linear(single_dim, single_dim)  # D
        self.multi_dim = multi_dim
        self.single_dim = single_dim
        self.num_heads = num_heads
    def forward(self, query, key, multi_value, single_value):
        # query, key: (T,B,1,D) -- channel-averaged stream, D = single_dim
        # multi_value: (T,B,ch,D') -- per-channel stream, D' = multi_dim
        # single_value: (T,B,1,D)
        query = torch.transpose(query, 0, 1)  # (B,T,1,D)...[32, 150, 1, 256]
        key = torch.transpose(key, 0, 1)  # (B,T,1,D)...[32, 150, 1, 256]
        multi_value = torch.permute(multi_value, (1, 2, 0, 3))  # (B,ch,T,D')...[32, 4, 150, 64]
        single_value = torch.permute(single_value, (1, 2, 0, 3))  # (B,1,T,D)...[32, 1, 150, 256]
###########
        q = torch.split(self.q_proj(query), self.single_dim // self.num_heads, dim=-1)  # h x (B,T,1,D/h)
        q = torch.stack(q, dim=1)  # (B,h,T,1,D/h)...[32, 8, 150, 1, 32]
        k = torch.split(self.k_proj(key), self.single_dim // self.num_heads, dim=-1)  # h x (B,T,1,D/h)
        k = torch.stack(k, dim=1)  # (B,h,T,1,D/h)...[32, 8, 150, 1, 32]
        multi_v = torch.split(self.multi_v_proj(multi_value), self.multi_dim // self.num_heads,
                              dim=-1)  # h x (B,ch,T,D'/h)
        multi_v = torch.stack(multi_v, dim=1)  # (B,h,ch,T,D'/h)...[32, 8, 4, 150, 8]
        single_v = torch.split(self.single_v_proj(single_value), self.single_dim // self.num_heads,
                               dim=-1)  # h x (B,1,T,D/h)
        single_v = torch.stack(single_v, dim=1)  # (B,h,1,T,D/h)...[32, 8, 1, 150, 32]
        q = q.view(*q.shape[:-2], -1)  # flatten the channel dim into features: (B,h,T,D/h)
        k = k.view(*k.shape[:-2], -1)  # (B,h,T,D/h)
        normalizer = math.sqrt(q.shape[-1])  # scaled dot-product attention
sim_mat = torch.matmul(q, torch.transpose(k, -2, -1)) / normalizer # (B, h, T, T)
att_mat = torch.unsqueeze(nn.functional.softmax(sim_mat, dim=-1), 2) # (B, h, 1, T, T)
        # co-attention: the same attention matrix aggregates both value streams
multi_result = torch.matmul(att_mat, multi_v) # (B, h, ch, T, D'/h)
single_result = torch.matmul(att_mat, single_v) # (B, h, 1, T, D/h)
multi_result = torch.permute(multi_result, (3, 0, 2, 1, 4)) # (T, B, ch, h, D'/h)
single_result = torch.permute(single_result, (3, 0, 2, 1, 4)) # (T, B, 1, h, D/h)
multi_result = torch.reshape(multi_result, multi_result.shape[:-2] + (-1,)) # (T, B, ch, D')
single_result = torch.reshape(single_result, single_result.shape[:-2] + (-1,)) # (T, B, 1, D)
multi_result = self.multi_out_proj(multi_result)
single_result = self.single_out_proj(single_result)
return multi_result, single_result
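

# Illustrative shape check for MultiHeadCoAttention (a sketch, not part of the
# model). The sizes below (T=10, B=2, ch=4, D=256, D'=64, h=8) are assumptions
# chosen to mirror the shape comments above; any sizes satisfying the
# divisibility asserts would work.
def _mca_shape_example():
    mca = MultiHeadCoAttention(multi_dim=64, single_dim=256, num_heads=8)
    single = torch.randn(10, 2, 1, 256)  # (T,B,1,D) channel-averaged stream
    multi = torch.randn(10, 2, 4, 64)    # (T,B,ch,D') per-channel stream
    multi_out, single_out = mca(single, single, multi, single)
    assert multi_out.shape == (10, 2, 4, 64)    # (T,B,ch,D')
    assert single_out.shape == (10, 2, 1, 256)  # (T,B,1,D)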
class CoAttention(nn.Module):
    """Fuses a per-channel (multi) stream with a channel-averaged (single)
    stream via co-attention plus cross-frame self-attention, then projects
    back to the input embedding dimension."""

    def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
                 init_mult=1e-2):  # , pre_norm=True):
super().__init__()
self.init_mult = init_mult
self.in_single_proj = nn.Linear(embed_dim, single_dim) # single_dim == D
self.in_single_ln = nn.LayerNorm(single_dim)
self.in_multi_proj = nn.Linear(embed_dim, multi_dim) # multi_dim == D'
self.in_multi_ln = nn.LayerNorm(multi_dim)
self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
self.mca_single_out_ln = nn.LayerNorm(single_dim)
# default MHA input: (seq, batch, feature)
self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout, bias=True, kdim=None,
vdim=None)
self.mha_ln = nn.LayerNorm(single_dim)
self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)
self.miso = False
    def scale_weights(self):
        # Shrink the output projection toward zero so the block starts with a
        # near-zero contribution.
        self.cat_proj.bias.data *= 0.
        self.cat_proj.weight.data *= self.init_mult
    def forward(self, x):
        # x: (T,B,ch,F); e.g. (150, 32, 4, 768)
        frames, B, chans, feat_dim = x.shape
        single_x = torch.mean(x, dim=2)  # average over channels: (T,B,F)
        single_x = self.in_single_ln(self.in_single_proj(single_x)).unsqueeze(dim=-2)  # (T,B,1,D)
        multi_x = self.in_multi_ln(self.in_multi_proj(x))  # (T,B,ch,D')
        # MCA with residual connections
        multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x)  # (T,B,ch,D'), (T,B,1,D)
        single_x = single_x + single_mca
        multi_x = multi_x + multi_mca
multi_x = self.mca_multi_out_ln(multi_x) # (T,B,ch,D')
single_x = torch.squeeze(self.mca_single_out_ln(single_x), -2) # (T,B,D)
# MHA
single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False) # (T, B, D)
single_x = self.mha_ln(single_mha + single_x)
# join representations
single_x = single_x.unsqueeze(-2) # (T,B,1,D)
single_x_tile = torch.tile(single_x, (1, 1, chans, 1)) # (T,B,ch,D)
cat_x = torch.cat([single_x_tile, multi_x], dim=-1) # (T,B,ch,D+D')
out = self.cat_proj(cat_x) # (T,B,ch,F)
return out
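

# A minimal smoke test (a sketch; the batch/time/channel sizes are assumptions
# matching the shape comments above, not anything fixed by the model).
if __name__ == "__main__":
    coatt = CoAttention(embed_dim=768, single_dim=256, multi_dim=64, n_heads=8)
    coatt.scale_weights()  # near-zero init of the output projection
    x = torch.randn(10, 2, 4, 768)  # (T,B,ch,F)
    out = coatt(x)
    print(out.shape)  # torch.Size([10, 2, 4, 768])
    _mca_shape_example()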