# CSSM/method/Model.py
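"""Siamese U-Net-style encoder-decoder with a MambaCSSM bottleneck.

The two temporal inputs (t1, t2) are concatenated on the channel axis,
encoded through four conv + max-pool stages, exchanged through the Mamba
state-space bottleneck, and decoded with max-unpooling (using the stored
pooling indices) plus skip connections from the encoder.
"""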
import torch
from torch import nn

from method.MambaCSSM import MambaCSSM

class MambaCSSMUnet(nn.Module):
    def __init__(self, output_classes=2):
        super().__init__()
        #### Encoder: four conv blocks; each max-pool keeps its indices for decoder unpooling
self.conv_block_1 = nn.Sequential(
nn.Conv2d(6, 16, 3, 1, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 16, 3, 1, padding=1),
nn.BatchNorm2d(16),
nn.ReLU()
)
self.mp_block_1 = nn.MaxPool2d(2, 2, return_indices=True)
self.conv_block_2 = nn.Sequential(
nn.Conv2d(16, 32, 3, 1, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1),
nn.BatchNorm2d(32),
nn.ReLU()
)
self.mp_block_2 = nn.MaxPool2d(2, 2, return_indices=True)
self.conv_block_3 = nn.Sequential(
nn.Conv2d(32, 64, 3, 1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.mp_block_3 = nn.MaxPool2d(2, 2, return_indices=True)
self.conv_block_4 = nn.Sequential(
nn.Conv2d(64, 128, 3, 1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.mp_block_4 = nn.MaxPool2d(2, 2, return_indices=True)
        #### Mamba bottleneck
        # d_model=256 matches the flattened 16x16 spatial grid produced by
        # four 2x2 max-pools on a 256x256 input (256 -> 128 -> 64 -> 32 -> 16).
        self.mamba = MambaCSSM(num_layers=4, d_model=256, d_conv=4, d_state=16)
        #### Decoder: max-unpool, fuse the encoder skip connection, then deconv
self.mpu_block_4 = nn.MaxUnpool2d(2, 2)
self.conv_4 = nn.Sequential(
nn.Conv2d(256, 128, 3, 1, padding=1),
nn.ReLU()
)
self.deconv_4_block = nn.Sequential(
nn.ConvTranspose2d(128, 64, 3, 1, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 64, 3, 1, padding=1),
nn.ReLU()
)
self.mpu_block_3 = nn.MaxUnpool2d(2, 2)
self.conv_3 = nn.Sequential(
nn.Conv2d(128, 64, 3, 1, padding=1),
nn.ReLU()
)
self.deconv_3_block = nn.Sequential(
nn.ConvTranspose2d(64, 32, 3, 1, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, 32, 3, 1, padding=1),
nn.ReLU()
)
self.mpu_block_2 = nn.MaxUnpool2d(2, 2)
self.conv_2 = nn.Sequential(
nn.Conv2d(64, 32, 3, 1, padding=1),
nn.ReLU()
)
self.deconv_2_block = nn.Sequential(
nn.ConvTranspose2d(32, 16, 3, 1, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(16, 16, 3, 1, padding=1),
nn.ReLU()
)
self.mpu_block_1 = nn.MaxUnpool2d(2, 2)
self.conv_1 = nn.Sequential(
nn.Conv2d(32, 16, 3, 1, padding=1),
nn.ReLU()
)
self.deconv_1_block = nn.Sequential(
nn.ConvTranspose2d(16, 8, 3, 1, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(8, 6, 3, 1, padding=1),
nn.ReLU()
)
        # 1x1 conv head mapping the 6 recovered channels to class logits.
        self.conv_final = nn.Conv2d(6, output_classes, 1, 1)

    def forward(self, t1, t2):
        # Concatenate the two 3-channel inputs so the encoder sees 6 channels.
        t = torch.cat([t1, t2], dim=1)
x1 = self.conv_block_1(t)
f1, i1 = self.mp_block_1(x1)
x2 = self.conv_block_2(f1)
f2, i2 = self.mp_block_2(x2)
x3 = self.conv_block_3(f2)
f3, i3 = self.mp_block_3(x3)
x4 = self.conv_block_4(f3)
f4, i4 = self.mp_block_4(x4)
        b, c, h, w = f4.shape
        # Split the bottleneck features into the t1 and t2 halves, then flatten
        # the spatial grid so each of the c//2 channels becomes one token of
        # size h*w (this must equal d_model, i.e. 16*16 = 256 for 256x256
        # inputs). reshape (not view) is required here: the channel slices are
        # non-contiguous for batch sizes > 1.
        f4_t1 = f4[:, :c // 2].reshape(b, c // 2, h * w)
        f4_t2 = f4[:, c // 2:].reshape(b, c // 2, h * w)
        f5_t1, f5_t2 = self.mamba(f4_t1, f4_t2)
        # Restore the (B, C/2, H, W) spatial layout and re-join the halves.
        f5_t1 = f5_t1.reshape(b, c // 2, h, w)
        f5_t2 = f5_t2.reshape(b, c // 2, h, w)
        f5 = torch.cat([f5_t1, f5_t2], dim=1)
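        # Decoder: unpool each stage with its stored indices, concatenate the
        # matching encoder feature map (skip connection), then refine.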
f6 = self.mpu_block_4(f5, i4)
f7 = self.conv_4(torch.cat((x4, f6), dim=1))
f8 = self.deconv_4_block(f7)
f9 = self.mpu_block_3(f8, i3, output_size=x3.size())
f10 = self.conv_3(torch.cat((f9, x3), dim=1))
f11 = self.deconv_3_block(f10)
f12 = self.mpu_block_2(f11, i2)
f13 = self.conv_2(torch.cat((f12, x2), dim=1))
f14 = self.deconv_2_block(f13)
f15 = self.mpu_block_1(f14, i1)
f16 = self.conv_1(torch.cat((f15, x1), dim=1))
f17 = self.deconv_1_block(f16)
f18 = self.conv_final(f17)
return f18
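

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming MambaCSSM accepts two
    # (B, L, d_model) tensors and returns two tensors of the same shape.
    # The 256x256 input size matches the layout the bottleneck expects
    # (d_model = 16*16 = 256 after four poolings).
    model = MambaCSSMUnet(output_classes=2)
    t1 = torch.randn(2, 3, 256, 256)  # first-date image batch
    t2 = torch.randn(2, 3, 256, 256)  # second-date image batch
    logits = model(t1, t2)
    print(logits.shape)  # expected: torch.Size([2, 2, 256, 256])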