1+ """
2+ # Created: 2024-11-15 21:33
3+ # Copyright (C) 2024-now, RPL, KTH Royal Institute of Technology
4+ # Author: Qingwen Zhang (https://kin-zhang.github.io/)
5+ #
6+ # This file is part of
7+ # * DeltaFlow (https://github.com/Kin-Zhang/DeltaFlow)
8+ # * OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow)
9+ # If you find this repo helpful, please cite the respective publication as
10+ # listed on the above website.
11+ """
12+ import torch
13+ import torch .nn as nn
14+ import spconv .pytorch as spconv
15+ import spconv as spconv_core
16+ spconv_core .constants .SPCONV_ALLOW_TF32 = True
17+ from .encoder import DynamicVoxelizer , DynamicPillarFeatureNet
18+ import dztimer
19+
class SparseVoxelNet(nn.Module):
    """Voxelize multi-frame point clouds and build a sparse temporal-difference tensor.

    For each frame, points are dynamically voxelized and pooled into per-voxel
    features; the per-frame sparse grids are subtracted from the pc1 grid and
    accumulated with an exponential decay over time, producing a single
    `spconv.SparseConvTensor` ("delta") for the downstream network.

    Args:
        voxel_size: voxel edge lengths passed to the voxelizer / feature net.
        pseudo_image_dims: spatial shape of the sparse voxel grid.
        point_cloud_range: cropping range passed to the voxelizer / feature net.
        feat_channels: number of per-voxel feature channels.
        decay_factor: per-frame weight decay applied to older frames' differences.
        timer: optional dztimer.Timing instance; a fresh one is started if None.
    """

    def __init__(self, voxel_size, pseudo_image_dims, point_cloud_range,
                 feat_channels: int, decay_factor=1.0, timer=None) -> None:
        super().__init__()
        self.voxelizer = DynamicVoxelizer(voxel_size=voxel_size,
                                          point_cloud_range=point_cloud_range)
        self.feature_net = DynamicPillarFeatureNet(
            in_channels=3,
            feat_channels=(feat_channels, ),
            point_cloud_range=point_cloud_range,
            voxel_size=voxel_size,
            mode='avg')

        self.voxel_spatial_shape = pseudo_image_dims
        self.num_feature = feat_channels
        self.decay_factor = decay_factor
        if timer is None:
            self.timer = dztimer.Timing()
            self.timer.start("Total")
        else:
            self.timer = timer

    def process_batch(self, voxel_info_list, if_return_point_feats=False):
        """Run the pillar feature net on each sample and stack batch-indexed voxels.

        Args:
            voxel_info_list: one dict per batch sample with keys 'points' and
                'voxel_coords' (as produced by ``self.voxelizer``).
            if_return_point_feats: when True, also return the raw per-point
                features from the feature net.

        Returns:
            (voxel_feats, coors) where ``coors`` is int32 with columns
            [batch_idx, c2, c1, c0] (voxel coords reordered from the feature
            net's [c0, c1, c2] — presumably (x, y, z) -> (z, y, x); TODO confirm
            against DynamicPillarFeatureNet), plus the point-feature list when
            ``if_return_point_feats`` is set.
        """
        voxel_feats_list_batch = []
        voxel_coors_list_batch = []
        point_feats_lst = []

        for batch_index, voxel_info_dict in enumerate(voxel_info_list):
            points = voxel_info_dict['points']
            coordinates = voxel_info_dict['voxel_coords']
            voxel_feats, voxel_coors, point_feats = self.feature_net(points, coordinates)
            if if_return_point_feats:
                point_feats_lst.append(point_feats)
            # Prepend the batch index so all samples can share one sparse tensor.
            batch_indices = torch.full((voxel_coors.size(0), 1), batch_index,
                                       dtype=torch.long, device=voxel_coors.device)
            voxel_coors_batch = torch.cat([batch_indices, voxel_coors[:, [2, 1, 0]]], dim=1)
            voxel_feats_list_batch.append(voxel_feats)
            voxel_coors_list_batch.append(voxel_coors_batch)

        voxel_feats_sp = torch.cat(voxel_feats_list_batch, dim=0)
        coors_batch_sp = torch.cat(voxel_coors_list_batch, dim=0).to(dtype=torch.int32)

        if if_return_point_feats:
            return voxel_feats_sp, coors_batch_sp, point_feats_lst
        return voxel_feats_sp, coors_batch_sp

    def forward(self, input_dict) -> torch.Tensor:
        """Build the decayed sparse difference tensor between pc1 and earlier frames.

        Expects ``input_dict`` with 'pc0s', 'pc1s' and optional history frames
        'pch1s', 'pch2s', ... Returns a dict containing the sparse delta tensor
        plus per-frame voxelization info and voxel counts.
        """
        bz_ = len(input_dict['pc0s'])
        # Sort history keys descending so that reversed() below visits
        # pc0s first, then pch1s, pch2s, ... (oldest last).
        frame_keys = sorted([key for key in input_dict.keys() if key.startswith('pch')], reverse=True)
        frame_keys += ['pc0s']

        pc1_voxel_info_list = self.voxelizer(input_dict['pc1s'])
        pc1_voxel_feats_sp, pc1_coors_batch_sp = self.process_batch(pc1_voxel_info_list)
        pc1s_num_voxels = pc1_voxel_feats_sp.shape[0]
        sparse_max_size = [bz_, *self.voxel_spatial_shape, self.num_feature]
        sparse_pc1 = torch.sparse_coo_tensor(pc1_coors_batch_sp.t(), pc1_voxel_feats_sp, size=sparse_max_size)
        # Accumulator with pc1's sparsity pattern but all-zero values.
        sparse_diff = torch.sparse_coo_tensor(pc1_coors_batch_sp.t(), pc1_voxel_feats_sp * 0.0, size=sparse_max_size)
        pch1s_3dvoxel_infos_lst = None
        pc0_point_feats_lst = []

        # (0, 'pch2s'), (1, 'pch1s'), (2, 'pc0s')
        # reversed: (0, 'pc0s'), (1, 'pch1s'), (2, 'pch2s')
        for time_index, frame_key in enumerate(reversed(frame_keys)):
            self.timer[0].start("Point Feature Voxelize")
            pc = input_dict[frame_key]
            voxel_info_list = self.voxelizer(pc)

            if frame_key == 'pc0s':
                voxel_feats_sp, coors_batch_sp, pc0_point_feats_lst = self.process_batch(voxel_info_list, if_return_point_feats=True)
            else:
                voxel_feats_sp, coors_batch_sp = self.process_batch(voxel_info_list)

            sparse_pcx = torch.sparse_coo_tensor(coors_batch_sp.t(), voxel_feats_sp, size=sparse_max_size)
            # Older frames contribute with exponentially decayed weight.
            sparse_diff = sparse_diff + (sparse_pc1 - sparse_pcx) * pow(self.decay_factor, time_index)
            self.timer[0].stop()

            if frame_key == 'pc0s':
                pc0s_3dvoxel_infos_lst = voxel_info_list
                pc0s_num_voxels = voxel_feats_sp.shape[0]
            elif frame_key == 'pch1s':
                pch1s_3dvoxel_infos_lst = voxel_info_list

        self.timer[2].start("D_Delta_Sparse")
        # Coalesce once and reuse — the original called .coalesce() twice,
        # doing the expensive duplicate-merge pass a second time for no gain.
        # time_index deliberately leaks from the loop: it equals n_frames - 1,
        # so dividing by (time_index + 1) averages over the frame count.
        coalesced_diff = sparse_diff.coalesce()
        features = coalesced_diff.values() / (time_index + 1)
        indices = coalesced_diff.indices().t().to(dtype=torch.int32)
        all_pcdiff_sparse = spconv.SparseConvTensor(features.contiguous(), indices.contiguous(), self.voxel_spatial_shape, bz_)
        self.timer[2].stop()

        output = {
            'delta_sparse': all_pcdiff_sparse,
            'pch1_3dvoxel_infos_lst': pch1s_3dvoxel_infos_lst,
            'pc0_3dvoxel_infos_lst': pc0s_3dvoxel_infos_lst,
            'pc0_point_feats_lst': pc0_point_feats_lst,
            'pc0_num_voxels': pc0s_num_voxels,
            'pc1_3dvoxel_infos_lst': pc1_voxel_info_list,
            'pc1_num_voxels': pc1s_num_voxels,
            'd_num_voxels': indices.shape[0]
        }
        return output
120+
class BasicConvolutionBlock(nn.Module):
    """Sparse 3D convolution followed by BatchNorm and ReLU."""

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1, padding=0, indice_key=None):
        super().__init__()
        conv = spconv.SparseConv3d(
            inc, outc,
            kernel_size=ks,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
            indice_key=indice_key,
            algo=spconv.ConvAlgo.Native,
        )
        self.net = spconv.SparseSequential(
            conv,
            nn.BatchNorm1d(outc),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply conv -> BN -> ReLU to the sparse tensor `x`."""
        return self.net(x)
132+
class BasicDeconvolutionBlock(nn.Module):
    """Sparse inverse convolution (upsampling) followed by BatchNorm and ReLU."""

    def __init__(self, inc, outc, indice_key, ks=3):
        super().__init__()
        deconv = spconv.SparseInverseConv3d(
            inc, outc,
            kernel_size=ks,
            indice_key=indice_key,
            bias=False,
            algo=spconv.ConvAlgo.Native,
        )
        self.net = spconv.SparseSequential(
            deconv,
            nn.BatchNorm1d(outc),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply inverse-conv -> BN -> ReLU to the sparse tensor `x`."""
        return self.net(x)
143+
class ResidualBlock(nn.Module):
    """Two submanifold 3D convs with a residual shortcut.

    When the input/output channel counts differ (or stride != 1), the
    shortcut is projected through a 1x1x1 submanifold conv + BN.
    """
    expansion = 1

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1, padding=0):
        super().__init__()

        def _subm(cin, cout):
            return spconv.SubMConv3d(
                cin, cout,
                kernel_size=ks,
                stride=stride,
                dilation=dilation,
                padding=padding,
                bias=False,
                algo=spconv.ConvAlgo.Native,
            )

        # Main branch: conv-BN-ReLU-conv-BN (final ReLU applied after the add).
        self.net = spconv.SparseSequential(
            _subm(inc, outc),
            nn.BatchNorm1d(outc),
            nn.ReLU(inplace=True),
            _subm(outc, outc),
            nn.BatchNorm1d(outc),
        )

        identity_ok = (inc == outc * self.expansion) and stride == 1
        self.downsample = None if identity_ok else spconv.SparseSequential(
            spconv.SubMConv3d(inc, outc, kernel_size=1, dilation=1,
                              stride=stride, algo=spconv.ConvAlgo.Native),
            nn.BatchNorm1d(outc),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Residual add of main branch and (possibly projected) shortcut, then ReLU."""
        out = self.net(x)
        shortcut = x.features if self.downsample is None else self.downsample(x).features
        out = out.replace_feature(self.relu(out.features + shortcut))
        return out
176+
177+ '''
178+ Reference when I wrote MinkUNet:
179+ * https://github.com/PJLab-ADG/OpenPCSeg/blob/master/pcseg/model/segmentor/voxel/minkunet/minkunet.py
180+ * https://github.com/open-mmlab/mmdetection3d/blob/main/mmdet3d/models/backbones/minkunet_backbone.py
181+ * https://github.com/mit-han-lab/spvnas/blob/master/core/models/semantic_kitti/minkunet.py
182+ '''
class MinkUNet(nn.Module):
    """Sparse 3D U-Net (MinkUNet) backbone built on spconv.

    Four stride-2 encoder stages of residual blocks, then four inverse-conv
    decoder stages, each concatenating the skip features from the encoder
    stage at the matching resolution.

    Args:
        cs: channel plan. cs[0] is the input/stem channel count consumed by
            ``conv_input``; the remaining nine entries size the four encoder
            and four decoder stages. Defaults are tuples — the original used
            mutable list defaults, a classic shared-mutable-default pitfall;
            behavior is unchanged since the sequences are only indexed/sliced.
        num_layer: residual blocks per stage.
            NOTE(review): only the first eight entries are used below; the
            ninth is kept so existing 9-element callers keep working.
    """

    def __init__(self,
                 cs=(16, 32, 64, 128, 256, 256, 128, 64, 32, 16),
                 num_layer=(2, 2, 2, 2, 2, 2, 2, 2, 2)):
        super().__init__()

        inc = cs[0]
        cs = cs[1:]  # remove the first input channel after conv_input
        self.block = ResidualBlock

        # Stem: two submanifold convs sharing one indice key.
        self.conv_input = spconv.SparseSequential(
            spconv.SubMConv3d(inc, cs[0], kernel_size=3, stride=1, padding=1, bias=False,
                              indice_key="subm0", algo=spconv.ConvAlgo.Native),
            nn.BatchNorm1d(cs[0]),
            nn.ReLU(inplace=True),

            spconv.SubMConv3d(cs[0], cs[0], kernel_size=3, stride=1, padding=1, bias=False,
                              indice_key="subm0", algo=spconv.ConvAlgo.Native),
            nn.BatchNorm1d(cs[0]),
            nn.ReLU(inplace=True)
        )
        self.in_channels = cs[0]

        # Encoder: each stage halves resolution (ks=2, stride=2) then stacks
        # residual blocks. _make_layer advances self.in_channels as it goes.
        self.stage1 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm1"),
            *self._make_layer(self.block, cs[1], num_layer[0])
        )
        # inside every make_layer: self.in_channels = out_channels * block.expansion
        self.stage2 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm2"),
            *self._make_layer(self.block, cs[2], num_layer[1])
        )
        self.stage3 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm3"),
            *self._make_layer(self.block, cs[3], num_layer[2])
        )
        self.stage4 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm4"),
            *self._make_layer(self.block, cs[4], num_layer[3])
        )

        # Decoder: inverse conv reuses the encoder's indice key to upsample
        # back to that stage's resolution; in_channels accounts for the
        # concatenated skip features before the residual blocks.
        self.up1 = [BasicDeconvolutionBlock(self.in_channels, cs[5], ks=2, indice_key="subm4")]
        self.in_channels = cs[5] + cs[3] * self.block.expansion
        self.up1.append(nn.Sequential(*self._make_layer(self.block, cs[5], num_layer[4])))
        self.up1 = nn.ModuleList(self.up1)

        self.up2 = [BasicDeconvolutionBlock(cs[5], cs[6], ks=2, indice_key="subm3")]
        self.in_channels = cs[6] + cs[2] * self.block.expansion
        self.up2.append(nn.Sequential(*self._make_layer(self.block, cs[6], num_layer[5])))
        self.up2 = nn.ModuleList(self.up2)

        self.up3 = [BasicDeconvolutionBlock(cs[6], cs[7], ks=2, indice_key="subm2")]
        self.in_channels = cs[7] + cs[1] * self.block.expansion
        self.up3.append(nn.Sequential(*self._make_layer(self.block, cs[7], num_layer[6])))
        self.up3 = nn.ModuleList(self.up3)

        self.up4 = [BasicDeconvolutionBlock(cs[7], cs[8], ks=2, indice_key="subm1")]
        self.in_channels = cs[8] + cs[0] * self.block.expansion
        self.up4.append(nn.Sequential(*self._make_layer(self.block, cs[8], num_layer[7])))
        self.up4 = nn.ModuleList(self.up4)

    def _make_layer(self, block, out_channels, num_block, stride=1):
        """Stack ``num_block`` residual blocks; mutates self.in_channels to the stage's output width."""
        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride=stride)
        )
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_block):
            layers.append(
                block(self.in_channels, out_channels)
            )
        return layers

    def forward(self, x):
        """Encode `x` through four downsampling stages, decode with skip concatenation."""
        x = self.conv_input(x)
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        y1 = self.up1[0](x4)
        y1 = y1.replace_feature(torch.cat([y1.features, x3.features], dim=1))
        y1 = self.up1[1](y1)

        y2 = self.up2[0](y1)
        y2 = y2.replace_feature(torch.cat([y2.features, x2.features], dim=1))
        y2 = self.up2[1](y2)

        y3 = self.up3[0](y2)
        y3 = y3.replace_feature(torch.cat([y3.features, x1.features], dim=1))
        y3 = self.up3[1](y3)  # Dense shape: [B, C, X, Y, Z]; [B, 32, 256, 256, 16]

        y4 = self.up4[0](y3)
        y4 = y4.replace_feature(torch.cat([y4.features, x.features], dim=1))
        y4 = self.up4[1](y4)

        return y4