"""Source code for ``torchsparse.backbones.modules.blocks``: sparse convolution building blocks."""

from typing import List, Tuple, Union

import numpy as np
from torch import nn

from torchsparse import SparseTensor
from torchsparse import nn as spnn

__all__ = ['SparseConvBlock', 'SparseConvTransposeBlock', 'SparseResBlock']


class SparseConvBlock(nn.Sequential):
    """Sparse convolution block: ``spnn.Conv3d`` -> ``spnn.BatchNorm`` -> ``spnn.ReLU``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels (also the BatchNorm width).
        kernel_size: Convolution kernel size, a single int or a sequence of ints.
        stride: Convolution stride, a single int or a sequence of ints.
        dilation: Convolution dilation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        # Layers are created in the same order as before (conv, norm, act) so
        # any parameter initialisation that draws random numbers is unchanged.
        conv = spnn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )
        norm = spnn.BatchNorm(out_channels)
        # NOTE(review): positional True is presumably the ``inplace`` flag.
        act = spnn.ReLU(True)
        # Registered positionally, so child names stay "0", "1", "2".
        super().__init__(conv, norm, act)


class SparseConvTransposeBlock(nn.Sequential):
    """Sparse transposed convolution block.

    Stacks a ``spnn.Conv3d`` built with ``transposed=True``, a ``spnn.BatchNorm``
    and a ``spnn.ReLU``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels (also the BatchNorm width).
        kernel_size: Convolution kernel size, a single int or a sequence of ints.
        stride: Convolution stride, a single int or a sequence of ints.
        dilation: Convolution dilation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        # Same construction order as before: transposed conv, norm, activation.
        deconv = spnn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            transposed=True,
        )
        norm = spnn.BatchNorm(out_channels)
        # NOTE(review): positional True is presumably the ``inplace`` flag.
        act = spnn.ReLU(True)
        # Registered positionally, so child names stay "0", "1", "2".
        super().__init__(deconv, norm, act)


class SparseResBlock(nn.Module):
    """Sparse residual block computing ``relu(main(x) + shortcut(x))``.

    ``main`` is Conv3d -> BatchNorm -> ReLU -> Conv3d -> BatchNorm.

    ``shortcut`` depends on whether ``x`` can be added to ``main(x)`` directly:

    - When the channel count changes or ``np.prod(stride) != 1``, it is a
      projection: a kernel-size-1 Conv3d with the same stride, followed by
      BatchNorm.
    - Otherwise it is ``nn.Identity()``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel size of both main-branch convolutions.
        stride: Stride of the first main-branch convolution and of the
            projection shortcut; a single int or a sequence of ints.
        dilation: Dilation of both main-branch convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, List[int], Tuple[int, ...]],
        stride: Union[int, List[int], Tuple[int, ...]] = 1,
        dilation: int = 1,
    ) -> None:
        super().__init__()

        # Main branch; only the first convolution applies the stride.
        main_layers = [
            spnn.Conv3d(
                in_channels,
                out_channels,
                kernel_size,
                dilation=dilation,
                stride=stride,
            ),
            spnn.BatchNorm(out_channels),
            spnn.ReLU(True),
            spnn.Conv3d(out_channels, out_channels, kernel_size, dilation=dilation),
            spnn.BatchNorm(out_channels),
        ]
        self.main = nn.Sequential(*main_layers)

        # A projection is needed whenever the identity could not be added to
        # the main branch's output: different width, or a non-unit stride.
        needs_projection = in_channels != out_channels or np.prod(stride) != 1
        self.shortcut = (
            nn.Sequential(
                spnn.Conv3d(in_channels, out_channels, 1, stride=stride),
                spnn.BatchNorm(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

        # Final activation applied to the residual sum.
        self.relu = spnn.ReLU(True)

    def forward(self, x: SparseTensor) -> SparseTensor:
        """Return ``relu(main(x) + shortcut(x))``."""
        # NOTE(review): the main branch is evaluated before the shortcut, as in
        # the original expression. Keep this order in case the strided shortcut
        # reuses coordinates produced by the main branch — confirm against
        # torchsparse.
        out = self.main(x)
        out = out + self.shortcut(x)
        return self.relu(out)