# Source code for torchsparse.nn.functional.build_kmap

from typing import Tuple, Union

import torch

import torchsparse.backend
from torchsparse.nn import functional as F
from torchsparse.nn.utils import get_kernel_offsets
from torchsparse.utils import make_ntuple

__all__ = ['build_kernel_map']


def build_kernel_map(
    _coords: torch.Tensor,
    kernel_size: Union[int, Tuple[int, ...]] = 2,
    stride: Union[int, Tuple[int, ...]] = 2,
    tensor_stride: Union[int, Tuple[int, ...]] = 1,
    mode='hashmap',
) -> Tuple[torch.Tensor, ...]:
    """Build the kernel (neighbor) map for a sparse convolution.

    Args:
        _coords: Input coordinates. In ``'grid'`` mode the columns are
            reordered with ``[3, 0, 1, 2]`` before being passed to the backend,
            and returned output coordinates are reordered back with
            ``[1, 2, 3, 0]``. Presumably the layout is ``(x, y, z, batch)``;
            TODO confirm against callers.
        kernel_size: Kernel size, an int or a per-dimension tuple. It is
            expanded to 3 dimensions with ``make_ntuple``.
        stride: Convolution stride, an int or a per-dimension tuple. If any
            component is greater than 1, the output coordinates are
            downsampled.
        tensor_stride: Stride of the input sparse tensor.
        mode: ``'grid'`` uses the backend grid kernels. Any other value
            (default ``'hashmap'``) uses the hash-table path.

    Returns:
        A tuple. Without downsampling (hashmap mode) it is
        ``(nbmaps, nbsizes, input_mask, output_mask)``. With downsampling the
        output coordinates are inserted third, giving
        ``(nbmaps, nbsizes, out_coords, input_mask, output_mask)``. In grid
        mode, a 4-element backend result is returned as-is (as a tuple).
        Otherwise the element at index 2 has its columns reordered back to
        the input layout.
    """
    if mode == 'grid':
        # Move column 3 to the front for the grid backend (undone on output).
        coords = _coords[:, [3, 0, 1, 2]]
        stride = make_ntuple(stride, ndim=3)
        kernel_size = make_ntuple(kernel_size, ndim=3)
        tensor_stride = make_ntuple(tensor_stride, ndim=3)
        subm = not any(s > 1 for s in stride)

        device = coords.device
        stride_t = torch.tensor(stride, dtype=torch.int, device=device)
        kernel_size_t = torch.tensor(kernel_size, dtype=torch.int, device=device)
        tensor_stride_t = torch.tensor(tensor_stride, dtype=torch.int, device=device)

        if subm:
            func = torchsparse.backend.build_kernel_map_subm
        else:
            func = torchsparse.backend.build_kernel_map_downsample
        out = func(coords, coords.min(0).values, coords.max(0).values,
                   kernel_size_t, stride_t, tensor_stride_t)

        if len(out) == 4:
            return tuple(out)
        # The 5-element result carries output coordinates at index 2 in the
        # reordered layout; restore the caller's column order.
        return tuple(out[:2]) + (out[2][:, [1, 2, 3, 0]],) + tuple(out[3:])

    # Hashmap path. Offsets are built from the caller's (un-normalized)
    # kernel_size, as in the original code.
    offsets = get_kernel_offsets(kernel_size, stride=tensor_stride, device=_coords.device)
    references = F.sphash(_coords)
    kernel_size = make_ntuple(kernel_size, ndim=3)
    stride = make_ntuple(stride, ndim=3)
    downsample = any(s > 1 for s in stride)

    if downsample:
        coords = F.spdownsample(_coords, stride, kernel_size, tensor_stride)
    else:
        coords = _coords

    queries = F.sphash(coords, offsets)
    # 2-D table of matched input indices, with -1 marking "no neighbor".
    # Presumably rows are kernel offsets and columns are output points.
    results = F.sphashquery(queries, references)
    hit = results != -1
    nbsizes = torch.sum(hit, dim=1)
    nbmaps = torch.nonzero(hit)
    # Replace the row index with the matched input index. 2-D advanced
    # indexing is equivalent to the flat-view gather but needs no contiguity.
    nbmaps[:, 0] = results[nbmaps[:, 0], nbmaps[:, 1]]
    # important for build masks
    nbmaps = nbmaps.contiguous()
    input_mask, output_mask = torchsparse.backend.build_mask_from_kmap(
        _coords.shape[0], coords.shape[0], nbmaps.int(), nbsizes.int())

    if downsample:
        return nbmaps, nbsizes, coords, input_mask, output_mask
    return nbmaps, nbsizes, input_mask, output_mask