import math
import sys
from typing import Dict, List, Tuple, Union
if sys.version_info >= (3, 8):
from functools import cached_property
else:
from backports.cached_property import cached_property
import numpy as np
import torch
from torch import nn
import torchsparse
from torchsparse import SparseTensor
from torchsparse.nn import functional as F
from torchsparse.utils import make_ntuple
__all__ = ['Conv3d']
[docs]class Conv3d(nn.Module):
"""3D convolution layer for a sparse tensor.
Args:
in_channels (int): Number of channels in the input sparse tensor.
out_channels (int): Number of channels in the output sparse tensor.
kernel_size (int or tuple): Size of the 3D convolving kernel.
stride (int or tuple): Stride of the convolution. Default: 1.
dilation (int or tuple): Spacing between kernel elements. Default: 1.
bias (bool): If True, add a learnable bias to the output. Default: True.
transposed (bool): If True, use transposed convolution. Default: False.
config (dict): The 3D convolution configuration, which includes the
``kmap_mode`` (hashmap or grid), and ``epsilon`` (redundant
computation tolerance) and ``mm_thresh`` (mm/bmm threshold) when
using the adaptive matmul grouping. Default: None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
stride: Union[int, List[int], Tuple[int, ...]] = 1,
dilation: int = 1,
bias: bool = False,
transposed: bool = False,
config: Dict = None) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = make_ntuple(kernel_size, ndim=3)
self.stride = make_ntuple(stride, ndim=3)
self.dilation = dilation
self.transposed = transposed
if config is None:
config = {}
config['epsilon'] = config.get('epsilon', 0.0)
config['mm_thresh'] = config.get('mm_thresh', 0)
config['kmap_mode'] = config.get('kmap_mode', 'hashmap')
self.config = config
self.kernel_volume = int(np.prod(self.kernel_size))
if self.kernel_volume > 1:
self.kernel = nn.Parameter(
torch.zeros(self.kernel_volume, in_channels, out_channels))
else:
self.kernel = nn.Parameter(torch.zeros(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
[docs] def reset_parameters(self) -> None:
std = 1 / math.sqrt(
(self.out_channels if self.transposed else self.in_channels)
* self.kernel_volume)
self.kernel.data.uniform_(-std, std)
if self.bias is not None:
self.bias.data.uniform_(-std, std)
@cached_property
def _reordered_kernel(self) -> nn.Parameter:
kernel_data = torch.zeros_like(self.kernel.data)
ind = 0
while ind < self.kernel_volume - 1:
kernel_data[ind] = self.kernel.data[ind // 2].clone()
kernel_data[ind + 1] = \
self.kernel.data[self.kernel_volume - 1 - ind // 2].clone()
ind += 2
if self.kernel_volume % 2 == 1:
kernel_data[self.kernel_volume - 1] = \
self.kernel.data[self.kernel_volume // 2].clone()
return nn.Parameter(kernel_data, requires_grad=False)
[docs] def forward(self, input: SparseTensor) -> SparseTensor:
kernel = self.kernel
epsilon, mm_thresh = self.config['epsilon'], self.config['mm_thresh']
if torchsparse.backends.benchmark: # type: ignore
if self.training:
print('Warning: it is not recommended to enable '
+ 'torchsparse.backends.benchmark during the training.')
epsilon, mm_thresh = 0.0, 0
elif (self.config['epsilon'] != 0.0
or self.config['mm_thresh'] != 0) and \
len(kernel.data.shape) == 3:
kernel = self._reordered_kernel
return F.conv3d(input,
kernel,
kernel_size=self.kernel_size,
bias=self.bias,
stride=self.stride,
dilation=self.dilation,
transposed=self.transposed,
epsilon=epsilon,
mm_thresh=mm_thresh,
kmap_mode=self.config['kmap_mode'])