Source code for torchsparse.nn.functional.crop

from typing import Optional, Tuple

import torch

from torchsparse import SparseTensor

__all__ = ['spcrop']


def spcrop(
    input: SparseTensor,
    coords_min: Optional[Tuple[int, ...]] = None,
    coords_max: Optional[Tuple[int, ...]] = None,
) -> SparseTensor:
    """Keep only the points of ``input`` that fall inside an axis-aligned box.

    The box is tested against the first three columns of ``input.coords``.
    A row survives when, in each of those columns, its value is
    ``>= coords_min`` (when ``coords_min`` is given) and ``< coords_max``
    (when ``coords_max`` is given). A bound that is ``None`` is not applied.

    Args:
        input: Sparse tensor to crop.
        coords_min: Inclusive lower bound per coordinate column, or ``None``.
        coords_max: Exclusive upper bound per coordinate column, or ``None``.

    Returns:
        A new ``SparseTensor`` built from the surviving rows of
        ``input.coords`` and ``input.feats``, carrying ``input.stride``.
    """
    points = input.coords
    values = input.feats
    step = input.stride
    device = points.device

    # One flag per (row, coordinate column); a row is kept only if all
    # three of its flags remain True after both bound checks.
    inside = torch.ones((points.shape[0], 3), dtype=torch.bool, device=device)

    if coords_min is not None:
        lower = torch.tensor(coords_min, dtype=torch.int, device=device)
        inside &= points[:, :3] >= lower.unsqueeze(dim=0)

    if coords_max is not None:
        upper = torch.tensor(coords_max, dtype=torch.int, device=device)
        # NOTE: the upper bound is exclusive ("<" rather than "<=") for
        # backward compatibility with some existing detection codebases.
        # This may need to be reflected in the documentation, or switched
        # back to "<=" in the future.
        inside &= points[:, :3] < upper.unsqueeze(dim=0)

    keep = inside.all(dim=1)
    return SparseTensor(coords=points[keep], feats=values[keep], stride=step)