from typing import Optional, Tuple
import torch
from torchsparse import SparseTensor
# Names exported by a star-import of this module; only `spcrop` is listed.
__all__ = ['spcrop']
def spcrop(input: SparseTensor,
           coords_min: Optional[Tuple[int, ...]] = None,
           coords_max: Optional[Tuple[int, ...]] = None) -> SparseTensor:
    """Crop a sparse tensor to a box given by per-axis coordinate bounds.

    A row of ``input.coords`` is kept only if every one of its first three
    columns ``c`` satisfies ``coords_min <= c < coords_max`` (element-wise,
    per column). The matching rows of ``input.feats`` are kept alongside.

    NOTE(review): this assumes ``input.coords`` is an ``(N, >=3)`` integer
    tensor whose first three columns are the spatial coordinates -- confirm
    against the ``SparseTensor`` coordinate layout in use.

    Args:
        input: Sparse tensor to crop; its ``coords``, ``feats`` and
            ``stride`` attributes are read.
        coords_min: Inclusive lower bound per column, or ``None`` to leave
            the lower side unbounded.
        coords_max: Exclusive upper bound per column, or ``None`` to leave
            the upper side unbounded.

    Returns:
        A new ``SparseTensor`` built from the retained rows of ``coords``
        and ``feats``, with the same ``stride`` as ``input``.
    """
    coords, feats, stride = input.coords, input.feats, input.stride
    spatial = coords[:, :3]

    # One boolean per (row, column); a row survives only if all three
    # columns pass every bound that was supplied.
    keep = torch.ones((coords.shape[0], 3),
                      dtype=torch.bool,
                      device=coords.device)

    if coords_min is not None:
        # Shape (1, K) so the bound broadcasts across all rows.
        lower = torch.tensor(coords_min,
                             dtype=torch.int,
                             device=coords.device).unsqueeze(dim=0)
        keep &= (spatial >= lower)

    if coords_max is not None:
        upper = torch.tensor(coords_max,
                             dtype=torch.int,
                             device=coords.device).unsqueeze(dim=0)
        # Using "<" instead of "<=" is for backward compatibility (with
        # some existing detection codebases). This might need to be changed
        # back to "<=" in the future.
        keep &= (spatial < upper)

    row_mask = torch.all(keep, dim=1)
    return SparseTensor(coords=coords[row_mask],
                        feats=feats[row_mask],
                        stride=stride)