# Source code for torchsparse.nn.functional.downsample

from typing import Tuple, Union

import torch

import torchsparse.backend
from torchsparse.utils import make_ntuple

__all__ = ['spdownsample']


def spdownsample(
        coords: torch.Tensor,
        stride: Union[int, Tuple[int, ...]] = 2,
        kernel_size: Union[int, Tuple[int, ...]] = 2,
        tensor_stride: Union[int, Tuple[int, ...]] = 1) -> torch.Tensor:
    """Compute the deduplicated output coordinates of a strided downsampling.

    Args:
        coords: Coordinate tensor. Columns ``0:3`` are treated as spatial
            coordinates (they are rescaled by the stride); column ``3`` is
            carried through unchanged -- presumably the batch index, with an
            ``(N, 4)`` layout (assumed from the indexing; TODO confirm).
        stride: Downsampling stride, one int for all three axes or a 3-tuple.
        kernel_size: Kernel size, one int or a 3-tuple.
        tensor_stride: Stride of the input coordinates on the original grid,
            one int or a 3-tuple.

    Returns:
        The unique output coordinates, in the same column layout as
        ``coords``.

    Raises:
        NotImplementedError: If some axis has a stride that is neither 1 nor
            equal to its kernel size and ``coords`` is not on a CUDA device;
            only a CUDA kernel exists for that case.
    """
    stride = make_ntuple(stride, ndim=3)
    kernel_size = make_ntuple(kernel_size, ndim=3)
    tensor_stride = make_ntuple(tensor_stride, ndim=3)

    # Fast path: each axis has stride 1 or stride == kernel_size, so output
    # coordinates are just the input coordinates snapped to the output grid.
    if all(stride[k] in (1, kernel_size[k]) for k in range(3)):
        # Spacing of the output grid, expressed in original-grid units.
        sample_stride = torch.tensor(
            [stride[k] * tensor_stride[k] for k in range(3)],
            dtype=torch.int,
            device=coords.device).unsqueeze(dim=0)

        coords = coords.clone()
        # Snap down to a multiple of the stride: floor(x / s) * s.
        # torch.remainder takes the sign of the divisor, so this rounds
        # toward -inf for negative coordinates too (truncation would merge
        # e.g. -1 and 1 into 0 for s == 2). Integer arithmetic also avoids
        # float32 precision loss on large coordinates.
        xyz = coords[:, :3]
        coords[:, :3] = xyz - torch.remainder(xyz, sample_stride)

        # Move column 3 to the front so the sorted rows returned by
        # torch.unique are ordered by it first, then restore the layout.
        coords = coords[:, [3, 0, 1, 2]]
        coords = torch.unique(coords, dim=0)
        return coords[:, [1, 2, 3, 0]]

    if coords.device.type != 'cuda':
        raise NotImplementedError(
            'spdownsample with stride not in (1, kernel_size) is only '
            'implemented for CUDA tensors')

    if coords.shape[0] == 0:
        # max()/min() below cannot reduce over an empty dimension.
        return coords.clone()

    # The backend kernel is called with column 3 first, together with the
    # per-column max/min of the reordered coordinates.
    coords = coords[:, [3, 0, 1, 2]]
    out_coords = torchsparse.backend.downsample_cuda(
        coords, coords.max(0).values, coords.min(0).values,
        kernel_size, stride, tensor_stride)
    return out_coords[:, [1, 2, 3, 0]]