#include <vector>
#include "hist_cuda_core.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu

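// For reference, a sketch of the launcher expected from hist_cuda_core.cuh,
// inferred from the call site below (a hypothetical signature, not the
// actual header contents):
//
//   template <typename scalar_t>
//   void hist_cuda_core(cudaStream_t stream,
//                       const scalar_t *X, const scalar_t *Y,
//                       int batch, int dim, int num_X, int num_Y,
//                       float min_x, float min_y, float min_z,
//                       float max_x, float max_y, float max_z,
//                       int len_x, int len_y, int len_z,
//                       scalar_t *bins);
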
at::Tensor
hist_cuda(const at::Tensor &X, const at::Tensor &Y,
          const float min_x, const float min_y, const float min_z,
          const float max_x, const float max_y, const float max_z,
          const int len_x, const int len_y, const int len_z,
          const int mini_batch)
{
    TORCH_CHECK(X.is_contiguous(), "X tensor has to be contiguous");
    TORCH_CHECK(Y.is_contiguous(), "Y tensor has to be contiguous");

    TORCH_CHECK(X.is_cuda(), "X must be a CUDA tensor");
    TORCH_CHECK(Y.is_cuda(), "Y must be a CUDA tensor");

    const int batch = X.size(0);
    const int num_X = X.size(1);
    const int dim = X.size(2);
    const int num_Y = Y.size(1);

    // TORCH_CHECK streams its arguments into the error message, so the sizes
    // are passed directly rather than via printf-style "%d" placeholders.
    TORCH_CHECK(X.size(0) == Y.size(0),
                "batch_X (", X.size(0), ") != batch_Y (", Y.size(0), ").");
    TORCH_CHECK(X.size(2) == Y.size(2),
                "dim_X (", X.size(2), ") != dim_Y (", Y.size(2), ").");

    TORCH_CHECK(dim == 4,
                "dim (", dim, ") != 4; expected 3 coordinates (x, y, z) plus 1 padded-or-not indicator.");

    // One (len_x, len_y, len_z) histogram per batch element.
    auto bins = at::zeros({batch, len_x, len_y, len_z}, X.options());

    // Process the batch in chunks of at most mini_batch elements; iters is the
    // ceiling of batch / mini_batch (e.g. batch = 10, mini_batch = 4 -> 3 chunks).
    TORCH_CHECK(mini_batch > 0, "mini_batch (", mini_batch, ") must be positive.");
    const int iters = (batch + mini_batch - 1) / mini_batch;

    for (int i = 0; i < iters; ++i)
    {
        // The final chunk may be smaller than mini_batch.
        int mini_batch_ = mini_batch;
        if ((i + 1) * mini_batch > batch)
        {
            mini_batch_ = batch - i * mini_batch;
        }
        AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "hist_cuda_core", ([&] {
            // 64-bit chunk offset so large batches cannot overflow int arithmetic.
            const int64_t off = static_cast<int64_t>(i) * mini_batch;
            hist_cuda_core(at::cuda::getCurrentCUDAStream(),
                           X.data_ptr<scalar_t>() + off * num_X * dim,
                           Y.data_ptr<scalar_t>() + off * num_Y * dim,
                           mini_batch_, dim, num_X, num_Y,
                           min_x, min_y, min_z,
                           max_x, max_y, max_z,
                           len_x, len_y, len_z,
                           bins.data_ptr<scalar_t>() + off * len_x * len_y * len_z);
        }));
    }

    return bins;
}
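
// A minimal binding sketch (hypothetical module scaffolding, assuming the
// function is exposed as a standard PyTorch C++ extension, typically from a
// separate .cpp translation unit):
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("hist_cuda", &hist_cuda,
//             "Per-batch 3D histogram over a (len_x, len_y, len_z) grid (CUDA)");
//   }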