Source code for torchsparse.nn.functional.conv

from typing import List, Optional, Tuple, Union

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

import torchsparse
import torchsparse.backend
from torchsparse import SparseTensor
from torchsparse.nn import functional as F
from torchsparse.utils import make_ntuple

buffer = torch.Tensor()

__all__ = ['conv3d']


class ConvolutionFunction(Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.half)
    def forward(
        ctx,
        input: torch.Tensor,
        weight: torch.Tensor,
        nbmaps: torch.Tensor,
        nbsizes: torch.Tensor,
        buffer: torch.Tensor,
        sizes: Tuple[int, int],
        input_mask: torch.Tensor,
        output_mask: torch.Tensor,
        epsilon: float,
        mm_thresh: int,
        conv_mode: int,
        transposed: bool = False,
    ) -> torch.Tensor:
        input = input.contiguous()
        weight = weight.contiguous()
        nbmaps = nbmaps.int().contiguous()
        nbsizes = nbsizes.int().contiguous()

        if not input.device.type == 'cuda':
            if not transposed:
                output = torch.zeros(sizes[1],
                                     weight.size(-1),
                                     dtype=input.dtype,
                                     device=input.device)
            else:
                # TODO(Haotian): ensure the original, upsampled size to be the same.
                output = torch.zeros(sizes[0],
                                     weight.size(-1),
                                     dtype=input.dtype,
                                     device=input.device)

        if input.device.type == 'cuda':
            output = torchsparse.backend.convolution_forward_cuda(
                input,
                weight,
                nbmaps,
                nbsizes.cpu(),
                input_mask,
                output_mask,
                sizes[1] if not transposed else sizes[0],
                epsilon,
                int(mm_thresh),
                conv_mode,
                transposed,
                buffer,
            )
        elif input.device.type == 'cpu':
            torchsparse.backend.convolution_forward_cpu(input, output, weight,
                                                        nbmaps, nbsizes.cpu(),
                                                        transposed)
        else:
            # use the native pytorch XLA APIs for the TPU.
            cur_st = 0
            for kernel_idx in range(weight.shape[0]):
                cur_ed = cur_st + nbsizes[kernel_idx]
                in_map = nbmaps[cur_st:cur_ed, 0].long()
                out_map = nbmaps[cur_st:cur_ed, 1].long()
                cur_st += nbsizes[kernel_idx]

                if transposed:
                    in_map, out_map = out_map, in_map

                cur_feat = input[in_map]
                cur_feat = torch.mm(cur_feat, weight[kernel_idx])
                output[out_map] += cur_feat
        ctx.for_backwards = (input, weight, nbmaps, nbsizes, transposed)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        input, weight, nbmaps, nbsizes, transposed = ctx.for_backwards

        grad_input = torch.zeros_like(input)
        grad_weight = torch.zeros_like(weight)

        if grad_output.device.type == 'cuda':
            torchsparse.backend.convolution_backward_cuda(
                input,
                grad_input,
                grad_output.contiguous(),
                weight,
                grad_weight,
                nbmaps,
                nbsizes.cpu(),
                transposed,
            )
        elif grad_output.device.type == 'cpu':
            torchsparse.backend.convolution_backward_cpu(
                input,
                grad_input,
                grad_output.contiguous(),
                weight,
                grad_weight,
                nbmaps,
                nbsizes.cpu(),
                transposed,
            )
        else:
            raise NotImplementedError
        return (
            grad_input,
            grad_weight,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


[docs]def conv3d( input: SparseTensor, weight: torch.Tensor, kernel_size: Union[int, List[int], Tuple[int, ...]], bias: Optional[torch.Tensor] = None, stride: Union[int, List[int], Tuple[int, ...]] = 1, dilation: Union[int, Tuple[int, ...]] = 1, transposed: bool = False, epsilon: float = 0.0, mm_thresh: int = 0, kmap_mode: str = 'hashmap', ) -> SparseTensor: feats, coords = input.feats, input.coords kernel_size = make_ntuple(kernel_size, ndim=3) stride = make_ntuple(stride, ndim=3) dilation = make_ntuple(dilation, ndim=3) conv_mode_num = 0 global buffer if torchsparse.backends.benchmark: # type: ignore conv_mode_num = 1 if (epsilon == 0.0 and mm_thresh == 0) else 2 if buffer.shape[0] == 0 or buffer.dtype != input.F.dtype: buffer = torch.zeros(4000000 * 64, dtype=input.F.dtype, device=input.F.device, requires_grad=False) if kernel_size == (1, 1, 1) and stride == (1, 1, 1) and dilation == (1, 1, 1): feats = feats.matmul(weight) if bias is not None: feats += bias output = SparseTensor(coords=coords, feats=feats, stride=input.stride) elif not transposed: kmap = input.kmaps.get((input.stride, kernel_size, stride, dilation)) if kmap is None: if any(s > 1 for s in stride): kmap_out = F.build_kernel_map(coords, kernel_size, stride, input.stride, kmap_mode) if len(kmap_out) == 3: nbmaps, nbsizes, coords = kmap_out input_mask, output_mask = None, None elif len(kmap_out) == 5: nbmaps, nbsizes, coords, input_mask, output_mask = kmap_out else: raise NotImplementedError else: kmap_out = F.build_kernel_map(coords, kernel_size, stride, input.stride, kmap_mode) if len(kmap_out) == 2: nbmaps, nbsizes = kmap_out input_mask, output_mask = None, None elif len(kmap_out) == 4: nbmaps, nbsizes, input_mask, output_mask = kmap_out else: raise NotImplementedError nbsizes = nbsizes.cpu() kmap = [ nbmaps, nbsizes, (feats.shape[0], coords.shape[0]), input_mask, output_mask, ] input.kmaps[(input.stride, kernel_size, stride, dilation)] = kmap feats = ConvolutionFunction.apply( feats, weight, kmap[0], kmap[1], buffer, kmap[2], kmap[3], kmap[4], epsilon, mm_thresh, conv_mode_num, transposed, ) if bias is not None: feats += bias output = SparseTensor( coords=coords, feats=feats, stride=tuple(input.stride[k] * stride[k] for k in range(3)), ) else: tensor_stride = tuple(input.stride[k] // stride[k] for k in range(3)) kmap = input.kmaps[(tensor_stride, kernel_size, stride, dilation)] feats = ConvolutionFunction.apply( feats, weight, kmap[0], kmap[1], buffer, kmap[2], kmap[3], kmap[4], epsilon, mm_thresh, conv_mode_num, transposed, ) if bias is not None: feats += bias output = SparseTensor( coords=input.cmaps[tensor_stride], feats=feats, stride=tensor_stride, ) output.cmaps = input.cmaps output.cmaps.setdefault(output.stride, output.coords) output.kmaps = input.kmaps return output