ElmanGhazaei committed on
Commit b59f460 · verified · 1 Parent(s): 0914649

Upload 41 files

Files changed (41)
  1. README.md +70 -0
  2. data_maker/__init__.py +0 -0
  3. data_maker/data_provider.py +178 -0
  4. main.py +247 -0
  5. method/CSSM.ipynb +210 -0
  6. method/MambaCSSM.py +382 -0
  7. method/Model.py +173 -0
  8. method/__pycache__/Mamba.cpython-313.pyc +0 -0
  9. method/__pycache__/MambaCSSM.cpython-313.pyc +0 -0
  10. method/__pycache__/Model.cpython-313.pyc +0 -0
  11. pre_trained_weights/LEVIR+/levir_cd_+_cssm.pth +3 -0
  12. pre_trained_weights/LEVIR+/levir_layer_1.pth +3 -0
  13. pre_trained_weights/LEVIR+/levir_layer_2.pth +3 -0
  14. pre_trained_weights/LEVIR+/levir_layer_3.pth +3 -0
  15. pre_trained_weights/LEVIR+/levir_layer_4.pth +3 -0
  16. pre_trained_weights/LEVIR+/levir_layer_6.pth +3 -0
  17. pre_trained_weights/SYSU-CD/sysu.pth +3 -0
  18. pre_trained_weights/SYSU-CD/sysu_1.pth +3 -0
  19. pre_trained_weights/SYSU-CD/sysu_layer_1.pth +3 -0
  20. pre_trained_weights/SYSU-CD/sysu_layer_2.pth +3 -0
  21. pre_trained_weights/SYSU-CD/sysu_layer_3.pth +3 -0
  22. pre_trained_weights/SYSU-CD/sysu_layer_4.pth +3 -0
  23. pre_trained_weights/SYSU-CD/sysu_layer_5.pth +3 -0
  24. pre_trained_weights/SYSU-CD/sysu_layer_6.pth +3 -0
  25. pre_trained_weights/WHU-CD/whu.pth +3 -0
  26. pre_trained_weights/WHU-CD/whu_1.pth +3 -0
  27. pre_trained_weights/WHU-CD/whu_layer_1.pth +3 -0
  28. pre_trained_weights/WHU-CD/whu_layer_2.pth +3 -0
  29. pre_trained_weights/WHU-CD/whu_layer_3.pth +3 -0
  30. pre_trained_weights/WHU-CD/whu_layer_4.pth +3 -0
  31. pre_trained_weights/WHU-CD/whu_layer_5.pth +3 -0
  32. pre_trained_weights/WHU-CD/whu_layer_6.pth +3 -0
  33. utils/__pycache__/__init__.cpython-313.pyc +0 -0
  34. utils/__pycache__/imgutils.cpython-313.pyc +0 -0
  35. utils/__pycache__/make_data.cpython-313.pyc +0 -0
  36. utils/__pycache__/metric.cpython-313.pyc +0 -0
  37. utils/__pycache__/utils_loss.cpython-313.pyc +0 -0
  38. utils/loss/L.py +245 -0
  39. utils/loss/__pycache__/L.cpython-313.pyc +0 -0
  40. utils/metrics/__pycache__/ev.cpython-313.pyc +0 -0
  41. utils/metrics/ev.py +103 -0
README.md ADDED
@@ -0,0 +1,70 @@
+
+ <div align="center">
+
+ # CSSM
+ **Efficient Remote Sensing Change Detection with Change State Space Models**
+
+ [**E. Ghazaei**](https://scholar.google.com/citations?user=R-ghC00AAAAJ&hl=en), [**E. Aptoula**](https://sites.google.com/view/erchan-aptoula/)
+
+ Faculty of Engineering and Natural Sciences (VPALab), Sabanci University, Istanbul, Turkiye
+
+ [[Paper Link](https://arxiv.org/abs/2504.11080)]
+ </div>
+
+ ## 🛎️ Updates
+ * **`Notice 🐍🐍`**: CSSM has been accepted by [IEEE GRSL](https://ieeexplore.ieee.org/xpl/RecentIssue.jsp?punumber=8859)! We'd appreciate it if you could give this repo a ⭐️**star**⭐️ and stay tuned!
+ * **`Nov 05th, 2025`**: The CSSM model and training code have been uploaded. You are welcome to use them!
+
+ ---
+
+ ## 🚀 Overview
+
+ * [**CSSM**]() serves as an efficient, state-of-the-art (SOTA) benchmark for binary change detection.
+
+ <p align="center">
+ <img width="1395" height="579" alt="CSSM architecture overview" src="https://github.com/user-attachments/assets/dccfdfc5-98b4-443d-b170-07e5e3ec551d" />
+ </p>
+
+ ---
+
+ ## Datasets
+ We used [LEVIR-CD+](https://www.kaggle.com/datasets/mdrifaturrahman33/levir-cd-change-detection), [SYSU-CD](https://github.com/liumency/SYSU-CD), and [WHU-CD](http://gpcv.whu.edu.cn/data/building_dataset.html) as the main datasets, while [CDD](http://gpcv.whu.edu.cn/data/building_dataset.html) and [OSCD](https://www.kaggle.com/datasets/soumikrakshit/onera-satellite-change-detection-dataset) were included in the ablation study to demonstrate the robustness of our model under different conditions.
+
+ **Qualitative Analysis:**
+
+ <p align="center">
+ <img width="1379" height="357" alt="Qualitative comparison of change maps" src="https://github.com/user-attachments/assets/c63690af-fd07-40af-b991-2b5b33ff53af" />
+ </p>
+
+ ---
+ ## Results
+
+ ![Quantitative results](https://github.com/user-attachments/assets/36f7487a-c08b-4205-9c05-e9b909ef0c89)
+
+ ## Complexity
+
+ <div align="center">
+
+ ![Complexity comparison](https://github.com/user-attachments/assets/b4b50828-fdd0-4b31-a4c2-e802ec43b404)
+
+ </div>
+
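For a quick start, a minimal sketch of a training invocation matching the CLI that `main.py` (added later in this commit) defines; the dataset paths are placeholders:

```python
# A sketch, not an official entry point: it simply shells out to main.py
# with the argparse flags defined in this commit. Paths are placeholders.
import subprocess

subprocess.run([
    "python", "main.py",
    "--dataset", "levir",
    "--train_path", "/data/LEVIR-CD+/train",
    "--val_path", "/data/LEVIR-CD+/val",
    "--test_path", "/data/LEVIR-CD+/test",
    "--batch_size", "16",
    "--epochs", "50",
    "--save_dir", "./checkpoints",
], check=True)
```

For WHU-CD, `main.py` instead takes a single `--train_path` plus `--train_txt`, `--val_txt`, and `--test_txt` split lists.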
data_maker/__init__.py ADDED
File without changes
data_maker/data_provider.py ADDED
@@ -0,0 +1,178 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ import os
+ from torchvision import transforms
+ import pandas as pd
+
+ from torch.utils.data import Dataset
+
+
+ # Shared transforms: inputs are center-cropped to 256x256 and converted to tensors.
+ tfms_normal = transforms.Compose([
+     transforms.CenterCrop(size=(256, 256)),
+     transforms.ToTensor()
+     # transforms.Normalize(mean=[0.46, 0.44, 0.39], std=[0.19, 0.18, 0.19])
+ ])
+
+ tfms_target = transforms.CenterCrop(size=(256, 256))
+
+
+ class Data_provider_SYSU(Dataset):
+     """SYSU-CD: bitemporal pairs in `time1`/`time2`, binary masks in `label`."""
+
+     def __init__(self, path):
+         self.data_path = path
+         self.pre_path = os.path.join(path, "time1")
+         self.post_path = os.path.join(path, "time2")
+         self.target_path = os.path.join(path, "label")
+
+     def __len__(self):
+         return len(os.listdir(self.post_path))
+
+     def __getitem__(self, idx):
+         pre_list = os.listdir(self.pre_path)
+         post_list = os.listdir(self.post_path)
+         target_list = os.listdir(self.target_path)
+
+         pre_list.sort()
+         post_list.sort()
+         target_list.sort()
+
+         pre_image_path = os.path.join(self.pre_path, pre_list[idx])
+         post_image_path = os.path.join(self.post_path, post_list[idx])
+         target_path = os.path.join(self.target_path, target_list[idx])
+
+         pre_image = Image.open(pre_image_path)
+         post_image = Image.open(post_image_path)
+         target_image = Image.open(target_path)
+
+         pre_image = tfms_normal(pre_image)
+         post_image = tfms_normal(post_image)
+
+         # Map the {0, 255} mask to {0, 1} class indices.
+         target_image = torch.tensor(np.array(tfms_target(target_image)) / 255).long()
+
+         return pre_image, post_image, target_image
+
+
+ class Data_provider_levir(Dataset):
+     """LEVIR-CD(+): bitemporal pairs in `A`/`B`, binary masks in `label`."""
+
+     def __init__(self, path):
+         self.data_path = path
+         self.pre_path = os.path.join(path, "A")
+         self.post_path = os.path.join(path, "B")
+         self.target_path = os.path.join(path, "label")
+
+     def __len__(self):
+         return len(os.listdir(self.post_path))
+
+     def __getitem__(self, idx):
+         pre_list = os.listdir(self.pre_path)
+         post_list = os.listdir(self.post_path)
+         target_list = os.listdir(self.target_path)
+
+         pre_list.sort()
+         post_list.sort()
+         target_list.sort()
+
+         pre_image_path = os.path.join(self.pre_path, pre_list[idx])
+         post_image_path = os.path.join(self.post_path, post_list[idx])
+         target_path = os.path.join(self.target_path, target_list[idx])
+
+         pre_image = Image.open(pre_image_path)
+         post_image = Image.open(post_image_path)
+         target_image = Image.open(target_path)
+
+         pre_image = tfms_normal(pre_image)
+         post_image = tfms_normal(post_image)
+
+         # Crop the mask like the inputs (a no-op on tiles that are already 256x256),
+         # then map {0, 255} to {0, 1} class indices.
+         target_image = torch.tensor(np.array(tfms_target(target_image)) / 255).long()
+
+         return pre_image, post_image, target_image
+
+
+ class Data_provider_WHU(Dataset):
+     """WHU-CD: bitemporal pairs in `A`/`B`, masks in `label`; names come from a split list file."""
+
+     def __init__(self, path, file):
+         self.data_path = path
+         self.pre_path = os.path.join(path, "A")
+         self.post_path = os.path.join(path, "B")
+         self.target_path = os.path.join(path, "label")
+         # One file name per line, e.g. "whucd_00267.png".
+         self.data_names = np.array(pd.read_csv(file, names=["tt"]))
+
+     def __len__(self):
+         return len(self.data_names)
+
+     def __getitem__(self, idx):
+         name = self.data_names[idx].item()
+
+         pre_image_path = os.path.join(self.pre_path, name)
+         post_image_path = os.path.join(self.post_path, name)
+         target_path = os.path.join(self.target_path, name)
+
+         pre_image = Image.open(pre_image_path)
+         post_image = Image.open(post_image_path)
+         target_image = Image.open(target_path)
+
+         pre_image = tfms_normal(pre_image)
+         post_image = tfms_normal(post_image)
+
+         target_image = torch.tensor(np.array(tfms_target(target_image)) / 255).long()
+
+         return pre_image, post_image, target_image
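A minimal usage sketch for the providers above (the dataset root is a placeholder, and tiles are assumed to be at least 256×256 so the center crop is well defined):

```python
from torch.utils.data import DataLoader

from data_maker.data_provider import Data_provider_levir

# Hypothetical split directory containing A/, B/ and label/ subfolders.
train_ds = Data_provider_levir("/data/LEVIR-CD+/train")
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

pre, post, target = next(iter(train_dl))
print(pre.shape, post.shape, target.shape)
# torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256]) torch.Size([16, 256, 256])
```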
main.py ADDED
@@ -0,0 +1,247 @@
+ import sys
+ import os
+ import argparse
+ import copy
+ import random
+ import time
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from data_maker.data_provider import Data_provider_levir, Data_provider_SYSU, Data_provider_WHU
+ from method.Model import MambaCSSMUnet
+ from utils.metrics.ev import Evaluator
+ from utils.loss.L import lovasz_softmax
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Change Detection Training Script')
+
+     # Dataset arguments
+     parser.add_argument('--dataset', type=str, required=True,
+                         choices=['levir', 'sysu', 'whu'],
+                         help='Dataset to use: levir, sysu, or whu')
+     parser.add_argument('--train_path', type=str, required=True,
+                         help='Path to training data (for WHU: main data directory)')
+     parser.add_argument('--test_path', type=str, default=None,
+                         help='Path to test data (not used for WHU dataset)')
+     parser.add_argument('--val_path', type=str, default=None,
+                         help='Path to validation data (not used for WHU dataset)')
+
+     # WHU-CD specific arguments
+     parser.add_argument('--train_txt', type=str, default=None,
+                         help='Text file for WHU-CD training data (required for WHU dataset)')
+     parser.add_argument('--test_txt', type=str, default=None,
+                         help='Text file for WHU-CD test data (required for WHU dataset)')
+     parser.add_argument('--val_txt', type=str, default=None,
+                         help='Text file for WHU-CD validation data (required for WHU dataset)')
+
+     # Training hyperparameters
+     parser.add_argument('--batch_size', type=int, default=64,
+                         help='Batch size for training (default: 64)')
+     parser.add_argument('--epochs', type=int, default=50,
+                         help='Number of training epochs (default: 50)')
+     parser.add_argument('--lr', type=float, default=1e-3,
+                         help='Learning rate (default: 0.001)')
+     parser.add_argument('--step_size', type=int, default=10,
+                         help='Step size for learning rate scheduler (default: 10)')
+
+     # Model saving
+     parser.add_argument('--save_dir', type=str, default='./checkpoints',
+                         help='Directory to save model checkpoints (default: ./checkpoints)')
+     parser.add_argument('--model_name', type=str, default='best_model.pth',
+                         help='Name for saved model file (default: best_model.pth)')
+
+     # Other settings
+     parser.add_argument('--seed', type=int, default=42,
+                         help='Random seed (default: 42)')
+     parser.add_argument('--num_workers', type=int, default=4,
+                         help='Number of data loading workers (default: 4)')
+
+     return parser.parse_args()
+
+
+ def set_seed(seed=42):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+
+ def get_data_provider(dataset_name):
+     """Return the appropriate data provider class based on dataset name."""
+     providers = {
+         'levir': Data_provider_levir,
+         'sysu': Data_provider_SYSU,
+         'whu': Data_provider_WHU
+     }
+     return providers[dataset_name]
+
+
+ def seed_worker(worker_id):
+     # Note: every worker is reseeded with the same fixed seed for reproducibility.
+     worker_seed = 42
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
+
+
+ def train(model, data, loss_ce, opt, device, train_list):
+     model.train()
+     size = len(data.dataset)
+
+     for b, (pre, post, target) in enumerate(data):
+         pre, post, target = pre.to(device), post.to(device), target.to(device)
+
+         y_pred = model(pre, post)
+
+         # Cross-entropy plus Lovasz-softmax on the class probabilities.
+         loss = loss_ce(y_pred, target) + lovasz_softmax(F.softmax(y_pred, dim=1), target, ignore=255)
+
+         opt.zero_grad()
+         loss.backward()
+         opt.step()
+
+         train_list.append(loss.item())
+
+         print(f"loss: {loss.item():.4f} [{b * len(pre)} | {size}]")
+
+
+ def test(model, data, loss_ce, device, evaluator, val_list):
+     model.eval()
+     num_batch = len(data)
+     test_loss = 0
+
+     evaluator.reset()
+
+     with torch.no_grad():
+         for pre, post, target in data:
+             pre, post, target = pre.to(device), post.to(device), target.to(device)
+
+             y_pred = model(pre, post)
+             test_loss += loss_ce(y_pred, target).item()
+             output_clf = y_pred.data.cpu().numpy()
+             output_clf = np.argmax(output_clf, axis=1)
+             labels_clf = target.cpu().numpy()
+
+             evaluator.add_batch(labels_clf, output_clf)
+
+     test_loss /= num_batch
+     val_list.append(test_loss)
+     print(f"Validation Loss: {test_loss:.4f}")
+     print(f"IoU: {evaluator.Intersection_over_Union()}")
+     print(f"Confusion Matrix:\n{evaluator.confusion_matrix}")
+     return np.array(evaluator.Intersection_over_Union()).mean()
+
+
+ def main():
+     args = parse_args()
+
+     # Validate dataset requirements
+     if args.dataset == 'whu':
+         if not all([args.train_txt, args.test_txt, args.val_txt]):
+             print("Error: WHU dataset requires --train_txt, --test_txt, and --val_txt arguments")
+             sys.exit(1)
+     else:
+         if not all([args.test_path, args.val_path]):
+             print(f"Error: {args.dataset.upper()} dataset requires --train_path, --test_path, and --val_path arguments")
+             sys.exit(1)
+
+     # Set seed
+     set_seed(args.seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     # Setup device
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         print("Using CUDA")
+     else:
+         device = torch.device("cpu")
+         print("Using CPU")
+
+     # Create save directory
+     os.makedirs(args.save_dir, exist_ok=True)
+
+     # Load dataset
+     print(f"\nLoading {args.dataset.upper()} dataset...")
+     DataProvider = get_data_provider(args.dataset)
+
+     if args.dataset == 'whu':
+         # WHU uses a single data path with per-split text files
+         train_ds = DataProvider(args.train_path, args.train_txt)
+         test_ds = DataProvider(args.train_path, args.test_txt)
+         val_ds = DataProvider(args.train_path, args.val_txt)
+     else:
+         # LEVIR and SYSU use separate paths
+         train_ds = DataProvider(args.train_path)
+         test_ds = DataProvider(args.test_path)
+         val_ds = DataProvider(args.val_path)
+
+     # Create data loaders
+     train_dl = DataLoader(dataset=train_ds, batch_size=args.batch_size,
+                           shuffle=True, num_workers=args.num_workers,
+                           worker_init_fn=seed_worker)
+     val_dl = DataLoader(dataset=val_ds, batch_size=args.batch_size,
+                         shuffle=False, num_workers=1,
+                         worker_init_fn=seed_worker)
+     test_dl = DataLoader(dataset=test_ds, batch_size=args.batch_size,
+                          shuffle=False, num_workers=1,
+                          worker_init_fn=seed_worker)
+
+     # Initialize model
+     print("\nInitializing model...")
+     model = MambaCSSMUnet().to(device)
+
+     # Define loss and optimizer
+     loss_ce = nn.CrossEntropyLoss()
+     opt = torch.optim.Adam(params=model.parameters(), lr=args.lr)
+     scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=args.step_size)
+
+     # Training setup
+     train_list = []
+     val_list = []
+     evaluator = Evaluator(num_class=2)
+     best_val_iou = 0.0
+     best_model_weight = None
+
+     # Training loop
+     print(f"\nStarting training for {args.epochs} epochs...")
+     print("=" * 60)
+
+     for e in range(args.epochs):
+         print(f"\nEpoch: {e + 1}/{args.epochs}")
+         t1 = time.time()
+
+         train(model, train_dl, loss_ce, opt, device, train_list)
+
+         val_iou = test(model, val_dl, loss_ce, device, evaluator, val_list)
+
+         if val_iou > best_val_iou:
+             print(f"✓ Best model updated! IoU improved from {best_val_iou:.4f} to {val_iou:.4f}")
+             best_val_iou = val_iou
+             best_model_weight = copy.deepcopy(model.state_dict())
+
+             # Save best model
+             save_path = os.path.join(args.save_dir, args.model_name)
+             torch.save(best_model_weight, save_path)
+             print(f"Model saved to {save_path}")
+
+         scheduler.step()
+         print(f"Learning Rate: {scheduler.get_last_lr()}")
+
+         t2 = time.time()
+         print(f"Epoch Time: {t2 - t1:.2f} seconds")
+         print("-" * 60)
+
+     print("\n" + "=" * 60)
+     print(f"Training completed! Best IoU: {best_val_iou:.4f}")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
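For inference with the released checkpoints, a minimal sketch, assuming each `.pth` file stores a `state_dict` in the format `main.py` saves (the checkpoint layout is an assumption, not verified here):

```python
import torch
from method.Model import MambaCSSMUnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MambaCSSMUnet().to(device)
# Assumed to be a state_dict, matching what main.py saves via torch.save().
state = torch.load("pre_trained_weights/LEVIR+/levir_cd_+_cssm.pth", map_location=device)
model.load_state_dict(state)
model.eval()

pre = torch.randn(1, 3, 256, 256, device=device)   # pre-change image tensor
post = torch.randn(1, 3, 256, 256, device=device)  # post-change image tensor
with torch.no_grad():
    logits = model(pre, post)           # (1, 2, 256, 256) per-pixel class logits
    change_mask = logits.argmax(dim=1)  # (1, 256, 256), values in {0, 1}
```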
method/CSSM.ipynb ADDED
@@ -0,0 +1,210 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from Mamba import Mamba\n",
+     "import torch"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "x = torch.rand(size = (1,5,16))\n",
+     "\n",
+     "num_layers = 5\n",
+     "d_model = 16\n",
+     "d_state = 16\n",
+     "d_conv = 4\n",
+     "\n",
+     "mamba = Mamba(num_layers=num_layers,d_model=d_model, d_conv=d_conv, d_state=d_state)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "torch.Size([1, 5, 32, 16])\n",
+       "torch.Size([1, 5, 32, 16])\n",
+       "torch.Size([1, 5, 32, 16])\n",
+       "torch.Size([1, 5, 32, 16])\n",
+       "torch.Size([1, 5, 32, 16])\n"
+      ]
+     }
+    ],
+    "source": [
+     "y1,y2 = mamba(x,x)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "torch.Size([1, 5, 16])"
+       ]
+      },
+      "execution_count": 4,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "y2.shape"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "metadata": {},
+    "outputs": [
+     {
+      "ename": "TypeError",
+      "evalue": "include_paths() got an unexpected keyword argument 'cuda'",
+      "output_type": "error",
+      "traceback": [
+       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+       "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+       "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mxlstm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 2\u001b[0m xLSTMBlockStack,\n\u001b[1;32m 3\u001b[0m xLSTMBlockStackConfig,\n\u001b[1;32m 4\u001b[0m mLSTMBlockConfig,\n\u001b[1;32m 5\u001b[0m mLSTMLayerConfig,\n\u001b[1;32m 6\u001b[0m sLSTMBlockConfig,\n\u001b[1;32m 7\u001b[0m sLSTMLayerConfig,\n\u001b[1;32m 8\u001b[0m FeedForwardConfig,\n\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 11\u001b[0m cfg \u001b[38;5;241m=\u001b[39m xLSTMBlockStackConfig(\n\u001b[1;32m 12\u001b[0m mlstm_block\u001b[38;5;241m=\u001b[39mmLSTMBlockConfig(\n\u001b[1;32m 13\u001b[0m mlstm\u001b[38;5;241m=\u001b[39mmLSTMLayerConfig(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 30\u001b[0m \n\u001b[1;32m 31\u001b[0m )\n\u001b[1;32m 33\u001b[0m xlstm_stack \u001b[38;5;241m=\u001b[39m xLSTMBlockStack(cfg)\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m __version__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m2.0.2\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mblocks\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmlstm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mblock\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m mLSTMBlock, mLSTMBlockConfig\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mblocks\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmlstm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlayer\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m mLSTMLayer, mLSTMLayerConfig\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mblocks\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mslstm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mblock\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m sLSTMBlock, sLSTMBlockConfig\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/blocks/mlstm/block.py:5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Copyright (c) NXAI GmbH and its affiliates 2024\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# Maximilian Beck\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mdataclasses\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m dataclass, field\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mxlstm_block\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m xLSTMBlock, xLSTMBlockConfig\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlayer\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m mLSTMLayerConfig\n\u001b[1;32m 9\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mmLSTMBlockConfig\u001b[39;00m:\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/blocks/xlstm_block.py:12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcomponents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mln\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m LayerNorm\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmlstm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlayer\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m mLSTMLayer, mLSTMLayerConfig\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mslstm\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlayer\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m sLSTMLayer, sLSTMLayerConfig\n\u001b[1;32m 16\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mxLSTMBlockConfig\u001b[39;00m:\n\u001b[1;32m 18\u001b[0m mlstm: Optional[mLSTMLayerConfig] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/blocks/slstm/layer.py:15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcomponents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minit\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m small_init_init_\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m nn\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcell\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m sLSTMCell, sLSTMCellConfig\n\u001b[1;32m 18\u001b[0m \u001b[38;5;129m@dataclass\u001b[39m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01msLSTMLayerConfig\u001b[39;00m(sLSTMCellConfig):\n\u001b[1;32m 20\u001b[0m embedding_dim: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/blocks/slstm/cell.py:12\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautograd\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfunction\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m once_differentiable\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcuda_init\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m load\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msrc\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mvanilla\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 14\u001b[0m slstm_forward,\n\u001b[1;32m 15\u001b[0m slstm_forward_step,\n\u001b[1;32m 16\u001b[0m slstm_pointwise_function_registry,\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mcomponents\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutil\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m conditional_decorator, round_to_multiple, ParameterProxy\n",
+       "File \u001b[0;32m~/anaconda3/envs/CDDD/lib/python3.13/site-packages/xlstm/blocks/slstm/src/cuda_init.py:30\u001b[0m\n\u001b[1;32m 27\u001b[0m curdir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mdirname(\u001b[38;5;18m__file__\u001b[39m)\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available():\n\u001b[0;32m---> 30\u001b[0m os\u001b[38;5;241m.\u001b[39menviron[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCUDA_LIB\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39msplit(\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcpp_extension\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minclude_paths\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcuda\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlib\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\u001b[38;5;241m*\u001b[39m, name, sources, extra_cflags\u001b[38;5;241m=\u001b[39m(), extra_cuda_cflags\u001b[38;5;241m=\u001b[39m(), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 34\u001b[0m suffix \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
+       "\u001b[0;31mTypeError\u001b[0m: include_paths() got an unexpected keyword argument 'cuda'"
+      ]
+     }
+    ],
+    "source": [
+     "from xlstm import (\n",
+     "    xLSTMBlockStack,\n",
+     "    xLSTMBlockStackConfig,\n",
+     "    mLSTMBlockConfig,\n",
+     "    mLSTMLayerConfig,\n",
+     "    sLSTMBlockConfig,\n",
+     "    sLSTMLayerConfig,\n",
+     "    FeedForwardConfig,\n",
+     ")\n",
+     "\n",
+     "cfg = xLSTMBlockStackConfig(\n",
+     "    mlstm_block=mLSTMBlockConfig(\n",
+     "        mlstm=mLSTMLayerConfig(\n",
+     "            conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4\n",
+     "        )\n",
+     "    ),\n",
+     "    slstm_block=sLSTMBlockConfig(\n",
+     "        slstm=sLSTMLayerConfig(\n",
+     "            # backend=,\n",
+     "            num_heads=4,\n",
+     "            conv1d_kernel_size=4,\n",
+     "            bias_init=\"powerlaw_blockdependent\",\n",
+     "        ),\n",
+     "        feedforward=FeedForwardConfig(proj_factor=1.3, act_fn=\"gelu\"),\n",
+     "    ),\n",
+     "    context_length=256,\n",
+     "    num_blocks=7,\n",
+     "    embedding_dim=128,\n",
+     "    slstm_at=[1],\n",
+     "\n",
+     ")\n",
+     "\n",
+     "xlstm_stack = xLSTMBlockStack(cfg)\n",
+     "\n",
+     "x = torch.randn(4, 256, 128).to(torch.device(\"cuda\"))\n",
+     "xlstm_stack = xlstm_stack.to(torch.device(\"cuda\"))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 16,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os \n",
+     "import pandas as pd\n",
+     "import numpy as np\n",
+     "\n",
+     "t = os.path.join(\"/media/elman/backup/DG_CD/WHU-CD-256/list/train.txt\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 24,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "f = np.array((pd.read_csv(t,names=[\"ttt\"])))"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 27,
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "text/plain": [
+        "'whucd_00267.png'"
+       ]
+      },
+      "execution_count": 27,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "f[1].item()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "CDDD",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.13.2"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
method/MambaCSSM.py ADDED
@@ -0,0 +1,382 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ from einops import rearrange, repeat, einsum
+ from typing import Union
+
+
+ @dataclass
+ class ModelArgs:
+     d_model: int
+     n_layer: int
+     vocab_size: int
+     d_state: int = 16
+     expand: int = 2
+     dt_rank: Union[int, str] = 'auto'
+     d_conv: int = 4
+     pad_vocab_size_multiple: int = 8
+     conv_bias: bool = True
+     bias: bool = False
+
+     def __post_init__(self):
+         self.d_inner = int(self.expand * self.d_model)
+
+         if self.dt_rank == 'auto':
+             self.dt_rank = math.ceil(self.d_model / 16)
+
+         if self.vocab_size % self.pad_vocab_size_multiple != 0:
+             self.vocab_size += (self.pad_vocab_size_multiple
+                                 - self.vocab_size % self.pad_vocab_size_multiple)
+
+
+ class MambaBlock_CD(nn.Module):
+     def __init__(self, d_model, d_conv, d_state, bias=True, conv_bias=True):
+         """A bitemporal Mamba block for change detection: the block layout of Figure 3,
+         Section 3.4 in the Mamba paper [1], with the SSM scan replaced by the change
+         state space scan in `cssm`."""
+         super().__init__()
+
+         self.norm = RMSNorm(d_model=d_model)
+
+         self.d_inner = 2 * d_model
+         self.dt_rank = math.ceil(d_model / 16)
+
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)
+
+         # Defined but not used in forward().
+         self.mlp_1 = nn.Linear(self.d_inner, d_model)
+         self.mlp_2 = nn.Linear(self.d_inner, d_model)
+
+         self.conv1d = nn.Conv1d(
+             in_channels=self.d_inner,
+             out_channels=self.d_inner,
+             bias=conv_bias,
+             kernel_size=d_conv,
+             groups=self.d_inner,
+             padding=d_conv - 1,
+         )
+
+         # x_proj takes in `x` and outputs the input-specific Δ, B, C
+         self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
+
+         # dt_proj projects Δ from dt_rank to d_in
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+         A = repeat(torch.arange(1, d_state + 1), 'n -> d n', d=self.d_inner)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+         self.D_p = nn.Parameter(torch.ones(self.d_inner))
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
+
+     def forward(self, t1, t2):
+         # Keep the inputs for the residual connections at the end.
+         ee1 = t1
+         ee2 = t2
+
+         (b, l, d) = t1.shape
+         t1 = self.norm(t1)
+
+         t1_and_res = self.in_proj(t1)  # shape (b, l, 2 * d_in)
+         (t1, res1) = t1_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
+
+         t1 = rearrange(t1, 'b l d_in -> b d_in l')
+         t1 = self.conv1d(t1)[:, :, :l]
+         t1 = rearrange(t1, 'b d_in l -> b l d_in')
+
+         t1 = F.silu(t1)
+
+         (b, l, d) = t2.shape
+         t2 = self.norm(t2)
+
+         t2_and_res = self.in_proj(t2)  # shape (b, l, 2 * d_in)
+         (t2, res2) = t2_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
+
+         t2 = rearrange(t2, 'b l d_in -> b d_in l')
+         t2 = self.conv1d(t2)[:, :, :l]
+         t2 = rearrange(t2, 'b d_in l -> b l d_in')
+
+         t2 = F.silu(t2)
+
+         y1, y2 = self.cssm(t1, t2)
+
+         y1 = y1 * F.silu(res1)
+         y2 = y2 * F.silu(res2)
+
+         output1 = self.out_proj(y1)
+         output2 = self.out_proj(y2)
+
+         return output1 + ee1, output2 + ee2
+
+     def cssm(self, t1, t2):
+         (d_in, n) = self.A_log.shape
+
+         A = -torch.exp(self.A_log.float())  # shape (d_in, n)
+         D = self.D.float()
+
+         t1_dbl = self.x_proj(t1)  # (b, l, dt_rank + 2*n)
+
+         (delta, B, C) = t1_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
+         delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+         A_prim = -torch.exp(self.A_log.float())  # shape (d_in, n)
+         D_prim = self.D_p.float()
+
+         t2_dbl = self.x_proj(t2)  # (b, l, dt_rank + 2*n)
+
+         # Note: this overwrites the Δ computed from t1, so the step size used by
+         # both scans below comes from the post-change branch t2.
+         (delta, B_prim, C_prim) = t2_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)
+         delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+         y = self.selective_scan(t1, t2, delta, A, B, C, D, A_prim, B_prim, C_prim, D_prim)  # similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         return y
+
+     def selective_scan(self, t1, t2, delta, A, B, C, D, A_prim, B_prim, C_prim, D_prim):
+         (b, l, d_in) = t1.shape
+         n = A.shape[1]
+
+         deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
+         deltaB_u = einsum(delta, B, t1, 'b l d_in, b l n, b l d_in -> b l d_in n')
+         deltaB_u_prim = einsum(delta, B_prim, t2, 'b l d_in, b l n, b l d_in -> b l d_in n')
+
+         # First scan: the hidden state is driven by the absolute difference of the
+         # discretized bitemporal inputs.
+         x = torch.zeros((b, d_in, n), device=deltaA.device)
+         ys = []
+         for i in range(l):
+             x = deltaA[:, i] * x + torch.abs(deltaB_u[:, i] - deltaB_u_prim[:, i])
+             y1 = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
+             ys.append(y1)
+         y1 = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+         y1 = y1 + t1 * D
+
+         (b, l, d_in) = t2.shape
+         n = A_prim.shape[1]
+
+         deltaA_prim = torch.exp(einsum(delta, A_prim, 'b l d_in, d_in n -> b l d_in n'))
+
+         # Second scan: the same difference-driven input, with the t2 branch's A' and C'.
+         x = torch.zeros((b, d_in, n), device=deltaA.device)
+         ys = []
+         for i in range(l):
+             x = deltaA_prim[:, i] * x + torch.abs(deltaB_u[:, i] - deltaB_u_prim[:, i])
+             y2 = einsum(x, C_prim[:, i, :], 'b d_in n, b n -> b d_in')
+             ys.append(y2)
+         y2 = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+         y2 = y2 + t2 * D_prim
+
+         return y1, y2
+
+
+ class MambaCSSM(nn.Module):
+
+     def __init__(self, num_layers, d_model, d_conv, d_state, bias=True, conv_bias=True):
+         super().__init__()
+
+         self.layers = nn.ModuleList([
+             MambaBlock_CD(d_model, d_conv, d_state, bias=bias, conv_bias=conv_bias)
+             for _ in range(num_layers)
+         ])
+
+     def forward(self, t1, t2):
+         for layer in self.layers:
+             t1, t2 = layer(t1, t2)
+
+         return t1, t2
+
+
+ class MambaBlock(nn.Module):
+     def __init__(self, d_model, d_conv, d_state, bias=True, conv_bias=True):
+         """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
+         super().__init__()
+
+         self.d_inner = 2 * d_model
+         self.dt_rank = math.ceil(d_model / 16)
+
+         self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=bias)
+
+         self.conv1d = nn.Conv1d(
+             in_channels=self.d_inner,
+             out_channels=self.d_inner,
+             bias=conv_bias,
+             kernel_size=d_conv,
+             groups=self.d_inner,
+             padding=d_conv - 1,
+         )
+
+         # x_proj takes in `x` and outputs the input-specific Δ, B, C
+         self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
+
+         # dt_proj projects Δ from dt_rank to d_in
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+         A = repeat(torch.arange(1, d_state + 1), 'n -> d n', d=self.d_inner)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(self.d_inner))
+         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
+
+     def forward(self, x):
+         """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
+
+         Args:
+             x: shape (b, l, d)
+
+         Returns:
+             output: shape (b, l, d)
+
+         Official Implementation:
+             class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
+             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+         """
+         (b, l, d) = x.shape
+
+         x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
+         (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
+
+         x = rearrange(x, 'b l d_in -> b d_in l')
+         x = self.conv1d(x)[:, :, :l]
+         x = rearrange(x, 'b d_in l -> b l d_in')
+
+         x = F.silu(x)
+
+         y = self.ssm(x)
+
+         y = y * F.silu(res)
+
+         output = self.out_proj(y)
+
+         return output
+
+     def ssm(self, x):
+         """Runs the SSM. See:
+             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+             - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         Args:
+             x: shape (b, l, d_in)
+
+         Returns:
+             output: shape (b, l, d_in)
+
+         Official Implementation:
+             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+         """
+         (d_in, n) = self.A_log.shape
+
+         # Compute ∆ A B C D, the state space parameters.
+         #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+         #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+         #     and is why Mamba is called **selective** state spaces)
+
+         A = -torch.exp(self.A_log.float())  # shape (d_in, n)
+         D = self.D.float()
+
+         x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
+
+         (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
+         delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+         y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         return y
+
+     def selective_scan(self, u, delta, A, B, C, D):
+         """Does selective scan algorithm. See:
+             - Section 2 State Space Models in the Mamba paper [1]
+             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+             - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         This is the classic discrete state space formula:
+             x(t + 1) = Ax(t) + Bu(t)
+             y(t)     = Cx(t) + Du(t)
+         except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
+
+         Args:
+             u: shape (b, l, d_in)
+             delta: shape (b, l, d_in)
+             A: shape (d_in, n)
+             B: shape (b, l, n)
+             C: shape (b, l, n)
+             D: shape (d_in,)
+
+         Returns:
+             output: shape (b, l, d_in)
+
+         Official Implementation:
+             selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
+             Note: some parts were refactored out of `selective_scan_ref`, so the functionality doesn't match exactly.
+         """
+         (b, l, d_in) = u.shape
+         n = A.shape[1]
+
+         # Discretize continuous parameters (A, B)
+         #     - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
+         #     - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with the authors:
+         #       "A is the more important term and the performance doesn't change much with the simplification on B"
+         deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
+         deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
+
+         # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
+         # Note that the below is sequential, while the official implementation does a much faster parallel scan that
+         # is additionally hardware-aware (like FlashAttention).
+         x = torch.zeros((b, d_in, n), device=deltaA.device)
+         ys = []
+         for i in range(l):
+             x = deltaA[:, i] * x + deltaB_u[:, i]
+             y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
+             ys.append(y)
+         y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+         y = y + u * D
+
+         return y
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self,
+                  d_model: int,
+                  eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d_model))
+
+     def forward(self, x):
+         output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+         return output
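Read directly off `MambaBlock_CD.selective_scan` (a paraphrase of the code, not a formula quoted from the paper): both temporal branches update a hidden state driven by the absolute difference of their discretized inputs,

$$
h_i^{(k)} = \bar{A}^{(k)}\, h_{i-1}^{(k)} + \left|\, \Delta_i \odot B_i^{(1)} x_i^{(1)} - \Delta_i \odot B_i^{(2)} x_i^{(2)} \right|, \qquad
y_i^{(k)} = C_i^{(k)} h_i^{(k)} + D^{(k)} x_i^{(k)}, \quad k \in \{1, 2\},
$$

with $\bar{A}^{(k)} = \exp(\Delta_i A^{(k)})$ and $\Delta_i$ taken from the $t_2$ branch, as noted in the code. The plain `MambaBlock.selective_scan` in the same file uses the standard input term $\Delta_i \odot B_i x_i$ instead, so the change block keeps per-branch output maps $C^{(k)}, D^{(k)}$ while replacing the scan input with a bitemporal difference.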
method/Model.py ADDED
@@ -0,0 +1,173 @@
+ from torch import nn
+ import torch
+ from method.MambaCSSM import MambaCSSM
+
+
+ class MambaCSSMUnet(nn.Module):
+
+     def __init__(self, output_classes=2):
+         super(MambaCSSMUnet, self).__init__()
+
+         #### Encoder Conv
+         self.conv_block_1 = nn.Sequential(
+             nn.Conv2d(6, 16, 3, 1, padding=1),
+             nn.BatchNorm2d(16),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, 3, 1, padding=1),
+             nn.BatchNorm2d(16),
+             nn.ReLU()
+         )
+
+         self.mp_block_1 = nn.MaxPool2d(2, 2, return_indices=True)
+
+         self.conv_block_2 = nn.Sequential(
+             nn.Conv2d(16, 32, 3, 1, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, 3, 1, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU()
+         )
+
+         self.mp_block_2 = nn.MaxPool2d(2, 2, return_indices=True)
+
+         self.conv_block_3 = nn.Sequential(
+             nn.Conv2d(32, 64, 3, 1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, 3, 1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU()
+         )
+
+         self.mp_block_3 = nn.MaxPool2d(2, 2, return_indices=True)
+
+         self.conv_block_4 = nn.Sequential(
+             nn.Conv2d(64, 128, 3, 1, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, 3, 1, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU()
+         )
+
+         self.mp_block_4 = nn.MaxPool2d(2, 2, return_indices=True)
+
+         #### Mamba bottleneck
+         self.mamba = MambaCSSM(num_layers=4, d_model=256, d_conv=4, d_state=16)
+
+         #### Decoder Deconv
+         self.mpu_block_4 = nn.MaxUnpool2d(2, 2)
+         self.conv_4 = nn.Sequential(
+             nn.Conv2d(256, 128, 3, 1, padding=1),
+             nn.ReLU()
+         )
+         self.deconv_4_block = nn.Sequential(
+             nn.ConvTranspose2d(128, 64, 3, 1, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(64, 64, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.mpu_block_3 = nn.MaxUnpool2d(2, 2)
+
+         self.conv_3 = nn.Sequential(
+             nn.Conv2d(128, 64, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.deconv_3_block = nn.Sequential(
+             nn.ConvTranspose2d(64, 32, 3, 1, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(32, 32, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.mpu_block_2 = nn.MaxUnpool2d(2, 2)
+
+         self.conv_2 = nn.Sequential(
+             nn.Conv2d(64, 32, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.deconv_2_block = nn.Sequential(
+             nn.ConvTranspose2d(32, 16, 3, 1, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(16, 16, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.mpu_block_1 = nn.MaxUnpool2d(2, 2)
+
+         self.conv_1 = nn.Sequential(
+             nn.Conv2d(32, 16, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.deconv_1_block = nn.Sequential(
+             nn.ConvTranspose2d(16, 8, 3, 1, padding=1),
+             nn.ReLU(),
+             nn.ConvTranspose2d(8, 6, 3, 1, padding=1),
+             nn.ReLU()
+         )
+
+         self.conv_final = nn.Conv2d(6, output_classes, 1, 1)
+
+     def forward(self, t1, t2):
+         # Early fusion: concatenate the bitemporal images along the channel axis.
+         t = torch.cat([t1, t2], dim=1)
+
+         x1 = self.conv_block_1(t)
+         f1, i1 = self.mp_block_1(x1)
+         x2 = self.conv_block_2(f1)
+         f2, i2 = self.mp_block_2(x2)
+         x3 = self.conv_block_3(f2)
+         f3, i3 = self.mp_block_3(x3)
+         x4 = self.conv_block_4(f3)
+         f4, i4 = self.mp_block_4(x4)
+
+         # Split the fused bottleneck back into the two temporal branches.
+         b, c, h, w = f4.shape
+         f4_t1 = f4[:, :c // 2, :, :]
+         f4_t2 = f4[:, c // 2:, :, :]
+
+         f4_t1 = f4_t1.view((-1, 64, 16 * 16))  # Adjusted for input size 256x256
+         f4_t2 = f4_t2.view((-1, 64, 16 * 16))  # Adjusted for input size 256x256
+         f5_t1, f5_t2 = self.mamba(f4_t1, f4_t2)
+         f5_t1 = f5_t1.view((-1, 64, 16, 16))  # Adjust the shape for further operations
+         f5_t2 = f5_t2.view((-1, 64, 16, 16))  # Adjust the shape for further operations
+
+         f5 = torch.cat([f5_t1, f5_t2], dim=1)
+
+         # Decoder with skip connections and max-unpooling.
+         f6 = self.mpu_block_4(f5, i4)
+         f7 = self.conv_4(torch.cat((x4, f6), dim=1))
+         f8 = self.deconv_4_block(f7)
+
+         f9 = self.mpu_block_3(f8, i3, output_size=x3.size())
+         f10 = self.conv_3(torch.cat((f9, x3), dim=1))
+         f11 = self.deconv_3_block(f10)
+
+         f12 = self.mpu_block_2(f11, i2)
+         f13 = self.conv_2(torch.cat((f12, x2), dim=1))
+
+         f14 = self.deconv_2_block(f13)
+
+         f15 = self.mpu_block_1(f14, i1)
+         f16 = self.conv_1(torch.cat((f15, x1), dim=1))
+         f17 = self.deconv_1_block(f16)
+         f18 = self.conv_final(f17)
+
+         return f18
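A minimal shape check for the wrapper above (a sketch; the hard-coded reshapes in `forward` assume 256×256 inputs, matching the data providers' center crop):

```python
import torch
from method.Model import MambaCSSMUnet

model = MambaCSSMUnet(output_classes=2)
t1 = torch.randn(2, 3, 256, 256)  # pre-change batch
t2 = torch.randn(2, 3, 256, 256)  # post-change batch

out = model(t1, t2)
print(out.shape)  # torch.Size([2, 2, 256, 256]): per-pixel change logits
```

The encoder halves the resolution four times (256 → 16), the two 64-channel halves of the bottleneck are flattened to (B, 64, 256) sequences for `MambaCSSM`, and the decoder unpools back to full resolution.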
method/__pycache__/Mamba.cpython-313.pyc ADDED
Binary file (17.3 kB).
 
method/__pycache__/MambaCSSM.cpython-313.pyc ADDED
Binary file (17.3 kB).
 
method/__pycache__/Model.cpython-313.pyc ADDED
Binary file (8.15 kB).
 
pre_trained_weights/LEVIR+/levir_cd_+_cssm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:faac6be557638ff677a1f9d83b6b3c0e02c6f198e84c82b54f109f46722b341c
+ size 17446716
pre_trained_weights/LEVIR+/levir_layer_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b30cefff0e7c839d8e3ac056e11b480044094400a6395aa770fd8eba957ab87f
+ size 6183447
pre_trained_weights/LEVIR+/levir_layer_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fbd45dcedf149ee139065f585bdc76ed67c1bdc5e4134fd86bd5b5fcf6d9bb93
+ size 8999194
pre_trained_weights/LEVIR+/levir_layer_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8815b516932a36eb24bf41db20272d3439cf275c33bf2c265dc2892b224e43b
+ size 11814942
pre_trained_weights/LEVIR+/levir_layer_4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:747f89b70c81fe1bb4b7b54b30a1ba86b5263f7d35e2bb720fe00068a561bfbc
+ size 14630690
pre_trained_weights/LEVIR+/levir_layer_6.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e12c2629f1797cdb6b59268a4c851956660139fac0a8da1ec406cdf69f2c8ec
+ size 20262122
pre_trained_weights/SYSU-CD/sysu.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcba1ce2bf614a0e357af4d973322f4714ac9890fcbc3a08a22eac8e4519342f
+ size 17434915
pre_trained_weights/SYSU-CD/sysu_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c6d03d55f3e770dd988585d2e4adad8c30008708836e176fac566adc4c1d442
+ size 17439289
pre_trained_weights/SYSU-CD/sysu_layer_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:896d763f09a5aec231565bfc21c2b464beb48496404001f4c96964bbd51e37e2
+ size 6183344
pre_trained_weights/SYSU-CD/sysu_layer_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecde174350c588e9cfc079e6d332d09f6bc28ab677645389034fae8f1098a878
+ size 8999074
pre_trained_weights/SYSU-CD/sysu_layer_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f040144f83e4aeaf8635b9a84736d1a3a1d815b8c1bb6452b99f1dfe8e40723c
+ size 11814805
pre_trained_weights/SYSU-CD/sysu_layer_4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ecca67ce52f965d60c38eb7c405c35eaf69d4b9ec05503f567135011f82bd35
+ size 14630536
pre_trained_weights/SYSU-CD/sysu_layer_5.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e631dec5c109e5b6ffefe898604fbe6fd1cfa35f738a0d3f5c2ee61ff963db9
+ size 20261934
pre_trained_weights/SYSU-CD/sysu_layer_6.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82a652e3b42a83ae2b5a3614afbdf26cab9f582c8a5734a274992ea6035232ff
+ size 20261934
pre_trained_weights/WHU-CD/whu.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c778f52d783c306397a4643ba1b2037ea166c0d39f84bae6de3008a3eb74c96
+ size 17434744
pre_trained_weights/WHU-CD/whu_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d56db7dcee57e0b01adbe84b4de54e2ca5a5810c7937b4534cf117558a340b6
+ size 17435086
pre_trained_weights/WHU-CD/whu_layer_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b05e358ddac70c2faf410a8fd1cf62756eceaf35909645d24124e9063824b075
+ size 6183241
pre_trained_weights/WHU-CD/whu_layer_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b31b24f29784692d63f798ab1d878ac7fb3e327e789fd59a4c0ac8bfd16f829
+ size 8998954
pre_trained_weights/WHU-CD/whu_layer_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:939034d4c1b74891fbc454d32e0f744d1244f73451598f923fd5dc77c2ba5a1a
+ size 11814668
pre_trained_weights/WHU-CD/whu_layer_4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25f2deaeeba1e5643519466f13b722b5d7bff0a510a5dd5f28184153719b0eea
+ size 14630382
pre_trained_weights/WHU-CD/whu_layer_5.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c161ddd7735e2d80adda186ab2e9854f80b02c77f4e4fe1fe788bc90d32631c
+ size 17446032
pre_trained_weights/WHU-CD/whu_layer_6.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a94d143f3ccafc97934c89285febbfde8ea6106ba65c6f55295ede407ba36df
3
+ size 20261746
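Each `.pth` entry above is a Git LFS pointer (spec v1): the repository tracks only the blob's SHA-256 `oid` and its `size` in bytes, and the actual weights are fetched from LFS storage on checkout. A minimal sketch for inspecting one of these checkpoints after it has been downloaded; the state-dict layout depends on how the training script saves weights, so the key access below is an assumption, not this repo's documented API:

```python
import torch

# Assumption: the file holds a plain state_dict or a dict wrapping one;
# adjust the access pattern to match how main.py calls torch.save().
ckpt = torch.load('pre_trained_weights/WHU-CD/whu.pth', map_location='cpu')
obj = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
for name, tensor in list(obj.items())[:5]:  # peek at the first few entries
    print(name, tuple(tensor.shape))
```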
utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (151 Bytes).
utils/__pycache__/imgutils.cpython-313.pyc ADDED
Binary file (4.81 kB).
utils/__pycache__/make_data.cpython-313.pyc ADDED
Binary file (7.37 kB).
utils/__pycache__/metric.cpython-313.pyc ADDED
Binary file (7.36 kB).
utils/__pycache__/utils_loss.cpython-313.pyc ADDED
Binary file (11.7 kB).
utils/loss/L.py ADDED
@@ -0,0 +1,245 @@
+ from __future__ import print_function, division
+
+ import torch
+ from torch.autograd import Variable
+ import torch.nn.functional as F
+ import numpy as np
+ try:
+     from itertools import ifilterfalse
+ except ImportError:  # py3k
+     from itertools import filterfalse as ifilterfalse
+
+
+ def lovasz_grad(gt_sorted):
+     """
+     Computes the gradient of the Lovasz extension w.r.t. sorted errors
+     See Alg. 1 in paper
+     """
+     p = len(gt_sorted)
+     gts = gt_sorted.sum()
+     intersection = gts - gt_sorted.float().cumsum(0)
+     union = gts + (1 - gt_sorted).float().cumsum(0)
+     jaccard = 1. - intersection / union
+     if p > 1:  # cover 1-pixel case
+         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+     return jaccard
+
+
+ def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
+     """
+     IoU for the foreground class
+     binary: 1 foreground, 0 background
+     """
+     if not per_image:
+         preds, labels = (preds,), (labels,)
+     ious = []
+     for pred, label in zip(preds, labels):
+         intersection = ((label == 1) & (pred == 1)).sum()
+         union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
+         if not union:
+             iou = EMPTY
+         else:
+             iou = float(intersection) / float(union)
+         ious.append(iou)
+     iou = mean(ious)  # mean across images if per_image
+     return 100 * iou
+
+
+ def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
+     """
+     Array of IoU for each (non ignored) class
+     """
+     if not per_image:
+         preds, labels = (preds,), (labels,)
+     ious = []
+     for pred, label in zip(preds, labels):
+         iou = []
+         for i in range(C):
+             if i != ignore:  # the ignored label is sometimes among predicted classes (ENet - CityScapes)
+                 intersection = ((label == i) & (pred == i)).sum()
+                 union = ((label == i) | ((pred == i) & (label != ignore))).sum()
+                 if not union:
+                     iou.append(EMPTY)
+                 else:
+                     iou.append(float(intersection) / float(union))
+         ious.append(iou)
+     ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
+     return 100 * np.array(ious)
+
+
+ # --------------------------- BINARY LOSSES ---------------------------
+
+
+ def lovasz_hinge(logits, labels, per_image=True, ignore=None):
+     """
+     Binary Lovasz hinge loss
+       logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
+       labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+       per_image: compute the loss per image instead of per batch
+       ignore: void class id
+     """
+     if per_image:
+         loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
+                     for log, lab in zip(logits, labels))
+     else:
+         loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
+     return loss
+
+
+ def lovasz_hinge_flat(logits, labels):
+     """
+     Binary Lovasz hinge loss
+       logits: [P] Variable, logits at each prediction (between -inf and +inf)
+       labels: [P] Tensor, binary ground truth labels (0 or 1)
+       ignore: label to ignore
+     """
+     if len(labels) == 0:
+         # only void pixels, the gradients should be 0
+         return logits.sum() * 0.
+     signs = 2. * labels.float() - 1.
+     errors = (1. - logits * Variable(signs))
+     errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+     perm = perm.data
+     gt_sorted = labels[perm]
+     grad = lovasz_grad(gt_sorted)
+     loss = torch.dot(F.relu(errors_sorted), Variable(grad))
+     return loss
+
+
+ def flatten_binary_scores(scores, labels, ignore=None):
+     """
+     Flattens predictions in the batch (binary case)
+     Removes labels equal to 'ignore'
+     """
+     scores = scores.view(-1)
+     labels = labels.view(-1)
+     if ignore is None:
+         return scores, labels
+     valid = (labels != ignore)
+     vscores = scores[valid]
+     vlabels = labels[valid]
+     return vscores, vlabels
+
+
+ class StableBCELoss(torch.nn.modules.Module):
+     def __init__(self):
+         super(StableBCELoss, self).__init__()
+     def forward(self, input, target):
+         neg_abs = - input.abs()
+         loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
+         return loss.mean()
+
+
+ def binary_xloss(logits, labels, ignore=None):
+     """
+     Binary cross-entropy loss
+       logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
+       labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+       ignore: void class id
+     """
+     logits, labels = flatten_binary_scores(logits, labels, ignore)
+     loss = StableBCELoss()(logits, Variable(labels.float()))
+     return loss
+
+
+ # --------------------------- MULTICLASS LOSSES ---------------------------
+
+
+ def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
+     """
+     Multi-class Lovasz-Softmax loss
+       probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
+               Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
+       labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+       classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+       per_image: compute the loss per image instead of per batch
+       ignore: void class labels
+     """
+     if per_image:
+         loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
+                     for prob, lab in zip(probas, labels))
+     else:
+         loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
+     return loss
+
+
+ def lovasz_softmax_flat(probas, labels, classes='present'):
+     """
+     Multi-class Lovasz-Softmax loss
+       probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
+       labels: [P] Tensor, ground truth labels (between 0 and C - 1)
+       classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+     """
+     if probas.numel() == 0:
+         # only void pixels, the gradients should be 0
+         return probas * 0.
+     C = probas.size(1)
+     losses = []
+     class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+     for c in class_to_sum:
+         fg = (labels == c).float()  # foreground for class c
+         if (classes == 'present' and fg.sum() == 0):
+             continue
+         if C == 1:
+             if len(classes) > 1:
+                 raise ValueError('Sigmoid output possible only with 1 class')
+             class_pred = probas[:, 0]
+         else:
+             class_pred = probas[:, c]
+         errors = (Variable(fg) - class_pred).abs()
+         errors_sorted, perm = torch.sort(errors, 0, descending=True)
+         perm = perm.data
+         fg_sorted = fg[perm]
+         losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
+     return mean(losses)
+
+
+ def flatten_probas(probas, labels, ignore=None):
+     """
+     Flattens predictions in the batch
+     """
+     if probas.dim() == 3:
+         # assumes output of a sigmoid layer
+         B, H, W = probas.size()
+         probas = probas.view(B, 1, H, W)
+     B, C, H, W = probas.size()
+     probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
+     labels = labels.view(-1)
+     if ignore is None:
+         return probas, labels
+     valid = (labels != ignore)
+     vprobas = probas[valid.nonzero().squeeze()]
+     vlabels = labels[valid]
+     return vprobas, vlabels
+
+ def xloss(logits, labels, ignore=None):
+     """
+     Cross entropy loss
+     """
+     return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
+
+
+ # --------------------------- HELPER FUNCTIONS ---------------------------
+ def isnan(x):
+     return x != x
+
+
+ def mean(l, ignore_nan=False, empty=0):
+     """
+     nanmean compatible with generators.
+     """
+     l = iter(l)
+     if ignore_nan:
+         l = ifilterfalse(isnan, l)
+     try:
+         n = 1
+         acc = next(l)
+     except StopIteration:
+         if empty == 'raise':
+             raise ValueError('Empty mean')
+         return empty
+     for n, v in enumerate(l, 2):
+         acc += v
+     if n == 1:
+         return acc
+     return acc / n
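A minimal usage sketch for the binary and multi-class losses defined in `L.py`; the tensor shapes and values below are illustrative only, not taken from this repo's training loop:

```python
import torch
from utils.loss.L import lovasz_hinge, lovasz_softmax

# Binary change detection: raw per-pixel logits in (-inf, +inf), masks in {0, 1}.
logits = torch.randn(4, 64, 64, requires_grad=True)
masks = (torch.rand(4, 64, 64) > 0.5).float()
loss = lovasz_hinge(logits, masks, per_image=True)
loss.backward()

# Multi-class variant expects probabilities [B, C, H, W], e.g. after a softmax.
probas = torch.softmax(torch.randn(4, 2, 64, 64), dim=1)
labels = torch.randint(0, 2, (4, 64, 64))
loss_mc = lovasz_softmax(probas, labels, classes='present')
```

Note that `lovasz_hinge` takes raw logits while `lovasz_softmax` takes normalized probabilities; mixing the two is a common source of silently wrong losses.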
utils/loss/__pycache__/L.cpython-313.pyc ADDED
Binary file (11.5 kB).
utils/metrics/__pycache__/ev.cpython-313.pyc ADDED
Binary file (7.36 kB).
utils/metrics/ev.py ADDED
@@ -0,0 +1,94 @@
+ import numpy as np
+
+
+ class Evaluator(object):
+     def __init__(self, num_class):
+         self.num_class = num_class
+         self.confusion_matrix = np.zeros((self.num_class,) * 2, dtype=np.longlong)
+         self._epsilon = 1e-7
+
+     def Pixel_Accuracy(self):
+         Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
+         return Acc
+
+     def Pixel_Accuracy_Class(self):
+         Acc = np.diag(self.confusion_matrix) / (self.confusion_matrix.sum(axis=1) + self._epsilon)
+         mAcc = np.nanmean(Acc)
+         return mAcc, Acc
+
+     def Pixel_Precision_Rate(self):
+         assert self.confusion_matrix.shape[0] == 2
+         Pre = self.confusion_matrix[1, 1] / (self.confusion_matrix[0, 1] + self.confusion_matrix[1, 1] + self._epsilon)
+         return Pre
+
+     def Pixel_Recall_Rate(self):
+         assert self.confusion_matrix.shape[0] == 2
+         Rec = self.confusion_matrix[1, 1] / (self.confusion_matrix[1, 0] + self.confusion_matrix[1, 1] + self._epsilon)
+         return Rec
+
+     def Pixel_F1_score(self):
+         assert self.confusion_matrix.shape[0] == 2
+         Rec = self.Pixel_Recall_Rate()
+         Pre = self.Pixel_Precision_Rate()
+         F1 = 2 * Rec * Pre / (Rec + Pre + self._epsilon)
+         return F1
+
+     def calculate_per_class_metrics(self):
+         # Start from index 1 to exclude class 0 (background) from the per-class counts
+         TPs = np.diag(self.confusion_matrix)[1:]
+         FNs = np.sum(self.confusion_matrix, axis=1)[1:] - TPs
+         FPs = np.sum(self.confusion_matrix, axis=0)[1:] - TPs
+         return TPs, FNs, FPs
+
+     def Damage_F1_score(self):
+         TPs, FNs, FPs = self.calculate_per_class_metrics()
+         precisions = TPs / (TPs + FPs + self._epsilon)
+         recalls = TPs / (TPs + FNs + self._epsilon)
+         f1_scores = 2 * (precisions * recalls) / (precisions + recalls + self._epsilon)
+         return f1_scores
+
+     Damage_F1_socore = Damage_F1_score  # backward-compatible alias for the original (misspelled) name
+
+     def Mean_Intersection_over_Union(self):
+         MIoU = np.nanmean(self.Intersection_over_Union())
+         return MIoU
+
+     def Intersection_over_Union(self):
+         IoU = np.diag(self.confusion_matrix) / (
+             np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
+             np.diag(self.confusion_matrix) + self._epsilon)
+         return IoU
+
+     def Kappa_coefficient(self):
+         num_total = np.sum(self.confusion_matrix)
+         observed_accuracy = np.trace(self.confusion_matrix) / num_total
+         expected_accuracy = np.sum(
+             np.sum(self.confusion_matrix, axis=0) / num_total * np.sum(self.confusion_matrix, axis=1) / num_total)
+
+         # Cohen's kappa: agreement beyond what is expected by chance
+         kappa = (observed_accuracy - expected_accuracy) / (1 - expected_accuracy)
+         return kappa
+
+     def Frequency_Weighted_Intersection_over_Union(self):
+         freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
+         iu = np.diag(self.confusion_matrix) / (
+             np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
+             np.diag(self.confusion_matrix))
+
+         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
+         return FWIoU
+
+     def _generate_matrix(self, gt_image, pre_image):
+         mask = (gt_image >= 0) & (gt_image < self.num_class)
+         label = self.num_class * gt_image[mask].astype('int64') + pre_image[mask]
+         count = np.bincount(label, minlength=self.num_class ** 2)
+         confusion_matrix = count.reshape(self.num_class, self.num_class)
+         return confusion_matrix
+
+     def add_batch(self, gt_image, pre_image):
+         assert gt_image.shape == pre_image.shape
+         self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
+
+     def reset(self):
+         # Keep the integer dtype from __init__ so counts stay exact after a reset
+         self.confusion_matrix = np.zeros((self.num_class,) * 2, dtype=np.longlong)
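A short sketch of the `Evaluator` in use, with random arrays standing in for real predictions and labels:

```python
import numpy as np
from utils.metrics.ev import Evaluator

evaluator = Evaluator(num_class=2)  # binary change detection

# Illustrative ground truth / predictions; in practice these come from the model.
gt = np.random.randint(0, 2, size=(4, 256, 256))
pred = np.random.randint(0, 2, size=(4, 256, 256))

evaluator.add_batch(gt, pred)
print('F1:   ', evaluator.Pixel_F1_score())
print('IoU:  ', evaluator.Intersection_over_Union())
print('Kappa:', evaluator.Kappa_coefficient())
evaluator.reset()  # clear accumulated counts before the next epoch
```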