Skip to content

Commit 99b1103

Browse files
committed
feat(deltaflow): update deltaflow model file.
* checked with trained weight.
1 parent 5d89648 commit 99b1103

5 files changed

Lines changed: 433 additions & 4 deletions

File tree

src/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
print(f"Detail error message\033[0m: {e}. Just ignore this warning if code runs without these models.")
2323

2424
# following need install extra package:
25-
# * pip install spconv-cu117
25+
# * pip install spconv-cu118
2626
try:
27+
from .deltaflow import DeltaFlow
2728
from .flow4d import Flow4D
2829
except ImportError as e:
2930
print("\033[93m--- WARNING [model]: Model with SparseConv is not imported, as it requires spconv lib which is not installed.")

src/models/basic/decoder.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,4 +271,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
271271
batchnorm_res = conv_res
272272
else:
273273
batchnorm_res = self.batchnorm(conv_res)
274-
return self.nonlinearity(batchnorm_res)
274+
return self.nonlinearity(batchnorm_res)
275+
276+
"""
277+
Note(2024/7/18 21:11 Qingwen):
278+
This is the decoder idea from DeFlow: https://github.com/KTH-RPL/DeFlow
279+
If you use/find this helpful, please cite the respective publication as listed on the above website.
280+
"""
281+
class SparseGRUHead(ConvGRUDecoder):
282+
def __init__(self, voxel_feat_dim: int = 96, point_feat_dim: int = 32, num_iters = 2):
283+
super(SparseGRUHead, self).__init__(pseudoimage_channels=point_feat_dim, num_iters=num_iters)
284+
self.offset_encoder = None # otherwise DDP may find parameter not used in training.
285+
self.gru = ConvGRU(input_dim=point_feat_dim, hidden_dim=voxel_feat_dim)
286+
self.decoder = nn.Sequential(
287+
nn.Linear(voxel_feat_dim + point_feat_dim, 32),
288+
nn.BatchNorm1d(32),
289+
nn.GELU(),
290+
nn.Linear(32, 3))
291+
292+
def forward_single(self, voxel_feat, voxel_coords, point_offsets):
293+
# [N, voxel_feat_dim] -> [N, voxel_feat_dim, 1]
294+
concatenated_vectors = (voxel_feat[:, voxel_coords[:,2], voxel_coords[:,1], voxel_coords[:,0]].T).unsqueeze(2)
295+
for itr in range(self.num_iters):
296+
concatenated_vectors = self.gru(concatenated_vectors, point_offsets.unsqueeze(2))
297+
298+
flow = self.decoder(torch.cat([concatenated_vectors.squeeze(2), point_offsets], dim=1))
299+
return flow
300+
301+
def forward(self, sparse_tensor, voxelizer_infos, pc0_point_feats_lst):
302+
303+
voxel_feats = sparse_tensor.dense()
304+
305+
flow_outputs = []
306+
batch_idx = 0
307+
for voxelizer_info in voxelizer_infos:
308+
voxel_coords = voxelizer_info["voxel_coords"]
309+
point_feat = pc0_point_feats_lst[batch_idx]
310+
voxel_feat = voxel_feats[batch_idx, :]
311+
flow = self.forward_single(voxel_feat, voxel_coords, point_feat)
312+
batch_idx += 1
313+
flow_outputs.append(flow)
314+
315+
return flow_outputs

src/models/basic/sparse_encoder.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
"""
2+
# Created: 2024-11-15 21:33
3+
# Copyright (C) 2024-now, RPL, KTH Royal Institute of Technology
4+
# Author: Qingwen Zhang (https://kin-zhang.github.io/)
5+
#
6+
# This file is part of
7+
# * DeltaFlow (https://github.com/Kin-Zhang/DeltaFlow)
8+
# * OpenSceneFlow (https://github.com/KTH-RPL/OpenSceneFlow)
9+
# If you find this repo helpful, please cite the respective publication as
10+
# listed on the above website.
11+
"""
12+
import torch
13+
import torch.nn as nn
14+
import spconv.pytorch as spconv
15+
import spconv as spconv_core
16+
spconv_core.constants.SPCONV_ALLOW_TF32 = True
17+
from .encoder import DynamicVoxelizer, DynamicPillarFeatureNet
18+
import dztimer
19+
20+
class SparseVoxelNet(nn.Module):
    """Sparse voxel encoder that builds a temporal feature-difference tensor.

    Every frame in the input batch is voxelized and pillar-encoded; the
    encoder then accumulates the decay-weighted difference between the pc1
    voxel features and each earlier frame's features into a single sparse
    tensor ('delta_sparse') consumed by the sparse backbone.
    """

    def __init__(self, voxel_size, pseudo_image_dims, point_cloud_range,
                 feat_channels: int, decay_factor=1.0, timer=None) -> None:
        """
        Args:
            voxel_size: voxel edge lengths passed to the voxelizer.
            pseudo_image_dims: spatial shape of the (sparse) voxel grid.
            point_cloud_range: min/max bounds used for voxelization.
            feat_channels: number of output feature channels per voxel.
            decay_factor: per-time-step weight applied to older frames.
            timer: optional external dztimer-style profiler; a private one is
                created when omitted.
        """
        super().__init__()
        self.voxelizer = DynamicVoxelizer(voxel_size=voxel_size,
                                          point_cloud_range=point_cloud_range)
        self.feature_net = DynamicPillarFeatureNet(
            in_channels=3,
            feat_channels=(feat_channels, ),
            point_cloud_range=point_cloud_range,
            voxel_size=voxel_size,
            mode='avg')

        self.voxel_spatial_shape = pseudo_image_dims
        self.num_feature = feat_channels
        self.decay_factor = decay_factor
        if timer is None:
            # Standalone timer so the module also works without an external profiler.
            self.timer = dztimer.Timing()
            self.timer.start("Total")
        else:
            self.timer = timer

    def process_batch(self, voxel_info_list, if_return_point_feats=False):
        """Run the pillar feature net on each batch element and stack results.

        Args:
            voxel_info_list: per-batch-element dicts with 'points' and
                'voxel_coords' from the voxelizer.
            if_return_point_feats: also collect per-point features per element.

        Returns:
            (voxel_feats, voxel_coords) stacked over the batch, coordinates
            prefixed with the batch index and reordered to (batch, z, y, x);
            plus the list of per-point features when requested.
        """
        voxel_feats_list_batch = []
        voxel_coors_list_batch = []
        point_feats_lst = []

        for batch_index, voxel_info_dict in enumerate(voxel_info_list):
            points = voxel_info_dict['points']
            coordinates = voxel_info_dict['voxel_coords']
            voxel_feats, voxel_coors, point_feats = self.feature_net(points, coordinates)
            if if_return_point_feats:
                point_feats_lst.append(point_feats)
            batch_indices = torch.full((voxel_coors.size(0), 1), batch_index, dtype=torch.long, device=voxel_coors.device)
            # Reverse the coordinate columns and prepend the batch index.
            voxel_coors_batch = torch.cat([batch_indices, voxel_coors[:, [2, 1, 0]]], dim=1)
            voxel_feats_list_batch.append(voxel_feats)
            voxel_coors_list_batch.append(voxel_coors_batch)

        voxel_feats_sp = torch.cat(voxel_feats_list_batch, dim=0)
        coors_batch_sp = torch.cat(voxel_coors_list_batch, dim=0).to(dtype=torch.int32)

        if if_return_point_feats:
            return voxel_feats_sp, coors_batch_sp, point_feats_lst

        return voxel_feats_sp, coors_batch_sp

    def forward(self, input_dict) -> torch.Tensor:
        """Encode a multi-frame point-cloud batch into a sparse delta tensor.

        Args:
            input_dict: batch dict with 'pc0s', 'pc1s' and optional history
                frames 'pch1s', 'pch2s', ... .

        Returns:
            dict with the sparse difference tensor ('delta_sparse'), per-frame
            voxelization info lists, pc0 per-point features, and voxel counts.
        """
        bz_ = len(input_dict['pc0s'])
        # History keys sorted descending ('pch2s', 'pch1s', ...), then pc0s last.
        frame_keys = sorted([key for key in input_dict.keys() if key.startswith('pch')], reverse=True)
        frame_keys += ['pc0s']

        pc1_voxel_info_list = self.voxelizer(input_dict['pc1s'])
        pc1_voxel_feats_sp, pc1_coors_batch_sp = self.process_batch(pc1_voxel_info_list)
        pc1s_num_voxels = pc1_voxel_feats_sp.shape[0]
        sparse_max_size = [bz_, *self.voxel_spatial_shape, self.num_feature]
        sparse_pc1 = torch.sparse_coo_tensor(pc1_coors_batch_sp.t(), pc1_voxel_feats_sp, size=sparse_max_size)
        # Zero-initialized accumulator sharing pc1's sparsity pattern.
        sparse_diff = torch.sparse_coo_tensor(pc1_coors_batch_sp.t(), pc1_voxel_feats_sp * 0.0, size=sparse_max_size)
        pch1s_3dvoxel_infos_lst = None
        pc0_point_feats_lst = []

        # (0, 'pch2s'), (1, 'pch1s'), (2, 'pc0s')
        # reversed: (0, 'pc0s'), (1, 'pch1s'), (2, 'pch2s')
        for time_index, frame_key in enumerate(reversed(frame_keys)):
            self.timer[0].start("Point Feature Voxelize")
            pc = input_dict[frame_key]
            voxel_info_list = self.voxelizer(pc)

            if frame_key == 'pc0s':
                voxel_feats_sp, coors_batch_sp, pc0_point_feats_lst = self.process_batch(voxel_info_list, if_return_point_feats=True)
            else:
                voxel_feats_sp, coors_batch_sp = self.process_batch(voxel_info_list)

            sparse_pcx = torch.sparse_coo_tensor(coors_batch_sp.t(), voxel_feats_sp, size=sparse_max_size)
            # Older frames are down-weighted by decay_factor ** time_index.
            sparse_diff = sparse_diff + (sparse_pc1 - sparse_pcx) * pow(self.decay_factor, time_index)
            self.timer[0].stop()

            if frame_key == 'pc0s':
                pc0s_3dvoxel_infos_lst = voxel_info_list
                pc0s_num_voxels = voxel_feats_sp.shape[0]
            elif frame_key == 'pch1s':
                pch1s_3dvoxel_infos_lst = voxel_info_list

        self.timer[2].start("D_Delta_Sparse")
        # Coalesce once and reuse: the original called .coalesce() twice,
        # repeating the expensive duplicate-index reduction. After the loop,
        # time_index + 1 equals the number of accumulated frames, so this
        # averages the accumulated differences.
        sparse_diff = sparse_diff.coalesce()
        features = sparse_diff.values() / (time_index + 1)
        indices = sparse_diff.indices().t().to(dtype=torch.int32)
        all_pcdiff_sparse = spconv.SparseConvTensor(features.contiguous(), indices.contiguous(), self.voxel_spatial_shape, bz_)
        self.timer[2].stop()

        output = {
            'delta_sparse': all_pcdiff_sparse,
            'pch1_3dvoxel_infos_lst': pch1s_3dvoxel_infos_lst,
            'pc0_3dvoxel_infos_lst': pc0s_3dvoxel_infos_lst,
            'pc0_point_feats_lst': pc0_point_feats_lst,
            'pc0_num_voxels': pc0s_num_voxels,
            'pc1_3dvoxel_infos_lst': pc1_voxel_info_list,
            'pc1_num_voxels': pc1s_num_voxels,
            'd_num_voxels': indices.shape[0]
        }
        return output
121+
class BasicConvolutionBlock(nn.Module):
    """Sparse 3D convolution followed by BatchNorm and ReLU."""

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1, padding=0, indice_key=None):
        super().__init__()
        conv = spconv.SparseConv3d(
            inc, outc,
            kernel_size=ks,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
            indice_key=indice_key,
            algo=spconv.ConvAlgo.Native,
        )
        # Same module order as before so state-dict keys are unchanged.
        self.net = spconv.SparseSequential(conv, nn.BatchNorm1d(outc), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.net(x)
133+
class BasicDeconvolutionBlock(nn.Module):
    """Sparse inverse 3D convolution followed by BatchNorm and ReLU."""

    def __init__(self, inc, outc, indice_key, ks=3):
        super().__init__()
        deconv = spconv.SparseInverseConv3d(
            inc, outc,
            kernel_size=ks,
            indice_key=indice_key,
            bias=False,
            algo=spconv.ConvAlgo.Native,
        )
        # Same module order as before so state-dict keys are unchanged.
        self.net = spconv.SparseSequential(deconv, nn.BatchNorm1d(outc), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.net(x)
144+
class ResidualBlock(nn.Module):
    """Two submanifold sparse convolutions with a residual connection.

    When the channel counts differ (or stride != 1), the identity path is
    projected through a 1x1x1 submanifold convolution before the addition.
    """

    expansion = 1

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1, padding=0):
        super().__init__()

        def _subm(cin, cout):
            # Shared factory for the two main-path convolutions.
            return spconv.SubMConv3d(cin, cout, kernel_size=ks, stride=stride,
                                     dilation=dilation, padding=padding,
                                     bias=False, algo=spconv.ConvAlgo.Native)

        # Module order matches the original so state-dict keys are unchanged.
        self.net = spconv.SparseSequential(
            _subm(inc, outc),
            nn.BatchNorm1d(outc),
            nn.ReLU(inplace=True),
            _subm(outc, outc),
            nn.BatchNorm1d(outc)
        )

        if inc == (outc * self.expansion) and stride == 1:
            # Identity shortcut needs no projection.
            self.downsample = None
        else:
            self.downsample = spconv.SparseSequential(
                spconv.SubMConv3d(inc, outc, kernel_size=1, dilation=1,
                                  stride=stride, algo=spconv.ConvAlgo.Native),
                nn.BatchNorm1d(outc)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Shortcut and main path are independent; compute the shortcut first.
        if self.downsample is None:
            shortcut = x.features
        else:
            shortcut = self.downsample(x).features
        out = self.net(x)
        out = out.replace_feature(out.features + shortcut)
        return out.replace_feature(self.relu(out.features))
177+
'''
Reference when I wrote MinkUNet:
* https://github.com/PJLab-ADG/OpenPCSeg/blob/master/pcseg/model/segmentor/voxel/minkunet/minkunet.py
* https://github.com/open-mmlab/mmdetection3d/blob/main/mmdet3d/models/backbones/minkunet_backbone.py
* https://github.com/mit-han-lab/spvnas/blob/master/core/models/semantic_kitti/minkunet.py
'''
class MinkUNet(nn.Module):
    """Sparse 3D U-Net backbone (MinkUNet-style) built on spconv.

    Four downsampling stages followed by four upsampling stages with skip
    connections; each up-path deconvolution shares an ``indice_key`` with the
    matching down-path convolution so spconv can invert the sparsity pattern.
    """

    def __init__(self,
                 cs=[16, 32, 64, 128, 256, 256, 128, 64, 32, 16],
                 num_layer=[2, 2, 2, 2, 2, 2, 2, 2, 2]):
        # NOTE(review): mutable default lists; harmless here because cs is
        # only re-bound (cs[1:]) and num_layer is never mutated.
        # cs: channel widths; cs[0] is the input width, cs[1:5] the encoder
        # stages, cs[5:] the decoder stages. num_layer: residual blocks per stage.
        super().__init__()

        inc = cs[0]
        cs = cs[1:] # remove the first input channel after conv_input
        self.block = ResidualBlock

        # Stem: two submanifold convs sharing "subm0" indices.
        self.conv_input = spconv.SparseSequential(
            spconv.SubMConv3d(inc, cs[0], kernel_size=3, stride=1, padding=1, bias=False, \
                indice_key="subm0", algo=spconv.ConvAlgo.Native),
            nn.BatchNorm1d(cs[0]),
            nn.ReLU(inplace=True),

            spconv.SubMConv3d(cs[0], cs[0], kernel_size=3, stride=1, padding=1, bias=False, \
                indice_key="subm0", algo=spconv.ConvAlgo.Native),
            nn.BatchNorm1d(cs[0]),
            nn.ReLU(inplace=True)
        )
        # Running channel count; updated as a side effect of _make_layer below.
        self.in_channels = cs[0]

        # Encoder: each stage halves resolution (ks=2, stride=2) then stacks
        # residual blocks.
        self.stage1 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm1"),
            *self._make_layer(self.block, cs[1], num_layer[0])
        )
        # inside every make_layer: self.in_channels = out_channels * block.expansion
        self.stage2 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm2"),
            *self._make_layer(self.block, cs[2], num_layer[1])
        )
        self.stage3 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm3"),
            *self._make_layer(self.block, cs[3], num_layer[2])
        )
        self.stage4 = nn.Sequential(
            BasicConvolutionBlock(self.in_channels, self.in_channels, ks=2, stride=2, indice_key="subm4"),
            *self._make_layer(self.block, cs[4], num_layer[3])
        )

        # Decoder: each up-block is [deconv, residual blocks]; in_channels is
        # widened for the skip concatenation before _make_layer runs.
        self.up1 = [BasicDeconvolutionBlock(self.in_channels, cs[5], ks=2, indice_key="subm4")]
        self.in_channels = cs[5] + cs[3] * self.block.expansion
        self.up1.append(nn.Sequential(*self._make_layer(self.block, cs[5], num_layer[4])))
        self.up1 = nn.ModuleList(self.up1)

        self.up2 = [BasicDeconvolutionBlock(cs[5], cs[6], ks=2, indice_key="subm3")]
        self.in_channels = cs[6] + cs[2] * self.block.expansion
        self.up2.append(nn.Sequential(*self._make_layer(self.block, cs[6], num_layer[5])))
        self.up2 = nn.ModuleList(self.up2)

        self.up3 = [BasicDeconvolutionBlock(cs[6], cs[7], ks=2, indice_key="subm2")]
        self.in_channels = cs[7] + cs[1] * self.block.expansion
        self.up3.append(nn.Sequential(*self._make_layer(self.block, cs[7], num_layer[6])))
        self.up3 = nn.ModuleList(self.up3)

        self.up4 = [BasicDeconvolutionBlock(cs[7], cs[8], ks=2, indice_key="subm1")]
        self.in_channels = cs[8] + cs[0] * self.block.expansion
        self.up4.append(nn.Sequential(*self._make_layer(self.block, cs[8], num_layer[7])))
        self.up4 = nn.ModuleList(self.up4)

    def _make_layer(self, block, out_channels, num_block, stride=1):
        """Build ``num_block`` residual blocks; the first may change channel
        count/stride. Side effect: advances self.in_channels."""
        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride=stride)
        )
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_block):
            layers.append(
                block(self.in_channels, out_channels)
            )
        return layers

    def forward(self, x):
        """Run the U-Net; returns features at the input resolution."""
        x = self.conv_input(x)
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        # Up path: deconv, concatenate the skip features, refine.
        y1 = self.up1[0](x4)
        y1 = y1.replace_feature(torch.cat([y1.features, x3.features], dim=1))
        y1 = self.up1[1](y1)

        y2 = self.up2[0](y1)
        y2 = y2.replace_feature(torch.cat([y2.features, x2.features], dim=1))
        y2 = self.up2[1](y2)

        y3 = self.up3[0](y2)
        y3 = y3.replace_feature(torch.cat([y3.features, x1.features], dim=1))
        y3 = self.up3[1](y3) # Dense shape: [B, C, X, Y, Z]; [B, 32, 256, 256, 16]

        y4 = self.up4[0](y3)
        y4 = y4.replace_feature(torch.cat([y4.features, x.features], dim=1))
        y4 = self.up4[1](y4)

        return y4

0 commit comments

Comments
 (0)