File size: 4,123 Bytes
cf7f9c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d3127
cf7f9c0
81d3127
cf7f9c0
 
 
 
 
 
 
 
 
 
 
 
 
81d3127
 
 
 
cf7f9c0
69d439b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
import io





def colorize_depth_map(depth, mask=None, reverse_color=False, color_map="Spectral"):
    """Render a normalized depth map as a color PIL image.

    Args:
        depth: 2-D array of depth values, assumed already normalized to [0, 1]
            by the caller (suitable as direct colormap input).
        mask: optional boolean array; pixels outside the mask are blacked out.
        reverse_color: if True, invert the depth (1 - depth) before coloring.
        color_map: name of a matplotlib colormap (default "Spectral").

    Returns:
        PIL.Image.Image containing the colorized depth map.
    """
    cmap = matplotlib.colormaps[color_map]
    # Depth is assumed normalized to [0, 1]; optionally flip the ramp.
    values = 1 - depth if reverse_color else depth
    # Drop the alpha channel, keep RGB as float in [0, 1].
    rgb_float = cmap(values, bytes=False)[:, :, 0:3]
    rgb_uint8 = (rgb_float * 255).astype(np.uint8)

    if mask is None:
        return Image.fromarray(rgb_uint8)

    # Black out everything outside the mask.
    masked = np.zeros_like(rgb_uint8)
    masked[mask] = rgb_uint8[mask]
    return Image.fromarray(masked)




def depth2disparity(depth, return_mask=False):
    """Convert a depth map to a disparity map (disparity = 1 / depth).

    Entries with depth <= 0 are left as 0 in the output.

    Args:
        depth: torch.Tensor or np.ndarray of depth values.
        return_mask: if True, also return the boolean mask of strictly
            positive entries that were actually inverted.

    Returns:
        disparity, or (disparity, mask) when return_mask is True. The output
        has the same type and shape as the input.

    Raises:
        TypeError: if depth is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        # Previously `disparity` was left unbound here, producing a confusing
        # NameError two lines below; fail fast with a clear message instead.
        raise TypeError(f"Unsupported depth type: {type(depth)!r}")
    non_negative_mask = depth > 0
    disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
    if return_mask:
        return disparity, non_negative_mask
    return disparity


def disparity2depth(disparity, **kwargs):
    """Convert a disparity map to a depth map.

    The mapping x -> 1/x (with non-positive entries left at 0) is its own
    inverse, so this simply delegates to depth2disparity. Accepts the same
    keyword arguments (e.g. return_mask=True).
    """
    return depth2disparity(disparity, **kwargs)




def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Affinely align a prediction to ground truth by least squares.

    Solves for scalar (scale, shift) minimizing
    ||scale * pred + shift - gt||^2 over pixels where valid_mask_arr is True,
    then applies them to the full-resolution pred_arr.

    Args:
        gt_arr: ground-truth array, squeezable to [H, W].
        pred_arr: prediction array, same squeezed shape as gt_arr.
        valid_mask_arr: boolean array, same squeezed shape, selecting the
            pixels used for the solve.
        return_scale_shift: if True, also return the fitted scale and shift
            (each a length-1 np.ndarray, matching the original behavior).
        max_resolution: if set, downsample gt/pred/mask so the longer spatial
            side is at most this many pixels before solving (the returned
            aligned prediction is still full resolution).

    Returns:
        aligned_pred, or (aligned_pred, scale, shift) when
        return_scale_shift is True.
    """
    ori_shape = pred_arr.shape  # remember input shape so we can restore it

    gt = gt_arr.squeeze()  # [H, W]
    pred = pred_arr.squeeze()
    valid_mask = valid_mask_arr.squeeze()

    # Optional downsample so the least-squares solve stays cheap.
    if max_resolution is not None:
        scale_factor = np.min(max_resolution / np.array(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            # BUG FIX: nn.Upsample performs 2-D resampling only on 4-D
            # (N, C, H, W) input; the original passed 3-D tensors, which
            # triggers 1-D (width-only) resampling. Add batch + channel
            # dims, then strip them again.
            gt = downscaler(torch.as_tensor(gt)[None, None]).squeeze().numpy()
            pred = downscaler(torch.as_tensor(pred)[None, None]).squeeze().numpy()
            valid_mask = (
                downscaler(torch.as_tensor(valid_mask)[None, None].float())
                .squeeze()
                .bool()
                .numpy()
            )

    assert (
        gt.shape == pred.shape == valid_mask.shape
    ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}"

    gt_masked = gt[valid_mask].reshape((-1, 1))
    pred_masked = pred[valid_mask].reshape((-1, 1))

    # Solve [pred, 1] @ [scale; shift] = gt in the least-squares sense.
    _ones = np.ones_like(pred_masked)
    A = np.concatenate([pred_masked, _ones], axis=-1)
    X = np.linalg.lstsq(A, gt_masked, rcond=None)[0]
    scale, shift = X

    # Apply the affine fit to the full-resolution prediction.
    aligned_pred = pred_arr * scale + shift

    # restore dimensions
    aligned_pred = aligned_pred.reshape(ori_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    else:
        return aligned_pred



def transfer_pred_disp2depth(all_pred_disparity, all_gt_depths, all_masks, return_scale_shift=False):
    """Convert predicted (affine-invariant) disparity to depth by aligning it
    against ground-truth depth with a least-squares scale/shift fit.

    Args:
        all_pred_disparity: predicted disparity array.
        all_gt_depths: ground-truth depth array (same shape).
        all_masks: boolean validity mask (same shape).
        return_scale_shift: if True, also return the fitted scale and shift.

    Returns:
        Predicted depths, or (depths, scale, shift) when
        return_scale_shift is True.
    """
    # GT depth -> GT disparity; keep the mask of strictly positive depths.
    gt_disparity, gt_pos_mask = depth2disparity(all_gt_depths, return_mask=True)

    # Fit only where prediction, GT and the supplied mask are all valid.
    pred_pos_mask = all_pred_disparity > 0
    joint_mask = pred_pos_mask & gt_pos_mask & all_masks

    aligned_disp, scale, shift = align_depth_least_square(
        gt_arr=gt_disparity,
        pred_arr=all_pred_disparity,
        valid_mask_arr=joint_mask,
        return_scale_shift=True,
        max_resolution=None,
    )

    # Clamp away zero/negative disparity before inverting back to depth.
    aligned_disp = np.clip(
            aligned_disp, a_min=1e-3, a_max=None
        )  # avoid 0 disparity
    pred_depths = disparity2depth(aligned_disp)

    if return_scale_shift:
        return pred_depths, scale, shift
    return pred_depths

    


"""
not gt needed to transfer
"""
def transfer_pred_disp2depth_v2(all_pred_disparity, scale, shift):
    """Apply a precomputed affine alignment (scale, shift) to predicted
    disparity and convert the result to depth. Unlike transfer_pred_disp2depth,
    this needs no ground truth.

    Args:
        all_pred_disparity: predicted disparity array.
        scale: multiplicative alignment factor (from a previous fit).
        shift: additive alignment offset (from a previous fit).

    Returns:
        Depth array with the same shape as the input.
    """
    original_shape = all_pred_disparity.shape
    # Affine-align the squeezed disparity, then restore the original shape.
    aligned = all_pred_disparity.squeeze() * scale + shift
    aligned = aligned.reshape(original_shape)

    # Clamp away zero/negative disparity before inverting to depth.
    aligned = np.clip(
            aligned, a_min=1e-3, a_max=None
        )  # avoid 0 disparity
    return disparity2depth(aligned)