import math
from typing import List, Tuple, Union
import torch
from torch import nn
from torchsparse import SparseTensor
# Public names exported by a star-import of this module.
__all__ = [
    'ToBEVConvolution',
    'ToBEVReduction',
    'ToDenseBEVConvolution',
    'ToBEVHeightCompression',
]
class ToBEVReduction(nn.Module):
    """Collapse one coordinate axis of a SparseTensor by averaging features.

    Every point has its coordinate along ``dim`` set to zero; points that
    end up with identical coordinates are merged and their features are
    replaced by the arithmetic mean of the merged group.

    Args:
        dim: Index of the coordinate column to collapse (default: 1).
    """

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Return a SparseTensor with axis ``dim`` flattened to zero.

        Args:
            input: Tensor exposing ``coords`` (N x D integer coordinates),
                ``feats`` (N x C features) and ``stride``.

        Returns:
            A SparseTensor with one row per distinct collapsed coordinate,
            mean-pooled features, and the input's stride unchanged.
        """
        coords, feats, stride = input.coords, input.feats, input.stride

        # Work on a copy so the caller's coordinates are not modified.
        coords = coords.clone()
        coords[:, self.dim] = 0

        # Prepend a column of ones: after duplicates are summed it holds the
        # number of points merged into each output location.
        feats = torch.cat([torch.ones_like(feats[:, :1]), feats], dim=1)

        # coalesce() sums the values of duplicate indices and keeps the
        # tensor sparse, which is what makes indices()/values() available.
        # (to_dense() would return a strided tensor without those accessors.)
        merged = torch.sparse_coo_tensor(coords.t().long(), feats).coalesce()

        coords = merged.indices().t().int()
        summed = merged.values()
        # Per-location feature sum divided by per-location point count.
        feats = summed[:, 1:] / summed[:, :1]
        return SparseTensor(coords=coords, feats=feats, stride=stride)
class ToDenseBEVConvolution(nn.Module):
    """Converts a SparseTensor into a dense BEV feature map.

    Group points with the same z value together and apply the same FC kernel.
    Aggregate the results by summing up all features within one BEV grid.

    Note:
        This module consumes larger memory than ``ToBEVHeightCompression``.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        shape: Shape of the full 3D grid; the entry at ``dim`` gives the
            number of height kernels, the other two give the BEV map size
        offset: Per-axis value subtracted from the coordinates before they
            are converted to BEV grid indices
        dim: Dimension index for z (default: 1 for KITTI coords)
        bias: Whether to use bias
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 shape: Union[List[int], Tuple[int, int, int], torch.Tensor],
                 offset: Tuple[int, int, int] = (0, 0, 0),
                 dim: int = 1,
                 bias: bool = False) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # A trailing 0 is appended so the offset never shifts the fourth
        # (batch) coordinate column.
        self.register_buffer('offset', torch.IntTensor([list(offset) + [0]]))
        if isinstance(shape, torch.Tensor):
            self.register_buffer('shape', shape.int())
        else:
            self.register_buffer('shape', torch.IntTensor(shape))
        self.dim = dim
        # One (in_channels x out_channels) kernel per height bin.
        self.n_kernels = int(self.shape[self.dim])
        self.bev_dims = [i for i in range(3) if i != self.dim]
        self.bev_shape = self.shape[self.bev_dims]
        self.kernel = nn.Parameter(
            torch.zeros(self.n_kernels, in_channels, out_channels))
        self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the kernel uniformly from ``[-1/sqrt(in), 1/sqrt(in)]``."""
        std = 1. / math.sqrt(self.in_channels)
        self.kernel.data.uniform_(-std, std)

    def forward(self, input: SparseTensor) -> torch.Tensor:
        """Project sparse features onto a dense BEV grid.

        Args:
            input: Tensor exposing ``coords`` (N x 4, batch index in the
                last column), ``feats`` (N x in_channels) and a 3-element
                ``stride``.

        Returns:
            Dense tensor of shape
            ``(batch, out_channels, bev_shape[0], bev_shape[1])``.
        """
        coords, feats, stride = input.coords, input.feats, input.stride
        # Shape (1, 3); kept whole so every axis can use its own stride.
        stride = torch.tensor(stride).unsqueeze(dim=0).to(feats)

        # Pick the kernel slice from the height bin of each point. The raw
        # (un-offset) coordinate is used here, as before.
        height_stride = stride[:, self.dim]
        kernel = torch.index_select(
            self.kernel, 0,
            torch.div(coords[:, self.dim], height_stride).trunc().long())
        feats = (feats.unsqueeze(dim=-1) * kernel).sum(1) + self.bias

        # Rows become [batch, bev axis 0, bev axis 1].
        coords = (coords - self.offset).t()[[3] + self.bev_dims].long()
        # Each BEV axis is divided by its own stride. Using the height
        # stride for both axes misplaces features whenever the strides are
        # not identical across axes.
        bev_stride = stride[:, self.bev_dims].t()
        coords[1:] = torch.div(coords[1:], bev_stride).trunc().long()

        bev_rows, bev_cols = self.bev_shape.tolist()
        cells_per_batch = int(self.bev_shape.prod())
        indices = (coords[0] * cells_per_batch + coords[1] * bev_cols
                   + coords[2])
        batch_size = coords[0].max().item() + 1

        # Densifying an un-coalesced sparse tensor sums the features of all
        # points that fall into the same BEV cell.
        output = torch.sparse_coo_tensor(
            indices.unsqueeze(dim=0),
            feats,
            torch.Size([batch_size * cells_per_batch,
                        feats.size(-1)]),
        ).to_dense()
        output = output.view(batch_size, bev_rows, bev_cols, -1)
        # Channels-last -> channels-first.
        output = output.permute(0, 3, 1, 2).contiguous()
        return output
class ToBEVConvolution(nn.Module):
    """Converts a SparseTensor into a sparse BEV feature map.

    Each point is multiplied by the FC kernel selected by its height bin,
    its coordinate along ``dim`` is set to zero, and points that then share
    a coordinate are merged by summing their features.

    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        n_kernels: Number of height bins (one FC kernel per bin)
        stride: Extra downsampling factor applied to the output coordinates
        dim: Dimension index for z (default: 1)
        bias: Whether to use bias
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 n_kernels: int,
                 stride: int = 1,
                 dim: int = 1,
                 bias: bool = False) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_kernels = n_kernels
        self.stride = stride
        self.dim = dim
        self.kernel = nn.Parameter(
            torch.zeros(n_kernels, in_channels, out_channels))
        self.bias = nn.Parameter(torch.zeros(1, out_channels)) if bias else 0
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the kernel uniformly from ``[-1/sqrt(in), 1/sqrt(in)]``."""
        std = 1. / math.sqrt(self.in_channels)
        self.kernel.data.uniform_(-std, std)

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Flatten the height axis and return a sparse BEV tensor.

        Args:
            input: Tensor exposing ``coords`` (N x 4, first three columns
                spatial), ``feats`` (N x in_channels) and a 3-element
                ``stride``.

        Returns:
            A SparseTensor whose coordinates are zero along ``dim`` and
            whose stride is the input stride scaled by ``self.stride``.
        """
        coords, feats, stride = input.coords, input.feats, input.stride
        # Scale every axis of the incoming stride. ``stride * self.stride``
        # on a tuple would repeat the tuple instead of multiplying it.
        ratio = tuple(int(s) * self.stride for s in stride)

        height_stride = torch.tensor(stride).unsqueeze(dim=0).to(
            feats)[:, self.dim]
        # index_select needs an integer index; the division yields floats.
        kernels = torch.index_select(
            self.kernel, 0,
            torch.div(coords[:, self.dim].long(),
                      height_stride).trunc().long())
        feats = (feats.unsqueeze(dim=-1) * kernels).sum(1) + self.bias

        # Force a copy so zeroing the height row never writes through to
        # the caller's coordinates (possible if they are already int64).
        coords = coords.t().to(torch.long, copy=True)
        coords[self.dim, :] = 0
        if self.stride > 1:
            # Snap the spatial rows to the coarser grid. An in-place true
            # division on an integer tensor is rejected by PyTorch, so
            # divide, truncate and multiply back explicitly.
            step = torch.tensor(ratio, device=coords.device).unsqueeze(dim=1)
            coords[:3] = torch.div(coords[:3], step).trunc().long() * step

        # coalesce() sums duplicate coordinates and keeps the tensor sparse
        # so values()/indices() are available.
        flatten = torch.sparse_coo_tensor(coords, feats).coalesce()
        return SparseTensor(flatten.values(),
                            flatten.indices().t().int(), ratio)
class ToBEVHeightCompression(nn.Module):
    """Converts a SparseTensor to a flattened volumetric tensor.

    The sparse features are scattered into a dense 3D grid and the height
    axis is then folded into the channel axis, giving a 2D BEV map.

    Args:
        channels: Number of input channels (stored only; the output has
            ``shape[dim]`` times the feature channels of the input)
        shape: Shape of the full 3D grid
        offset: Per-axis value subtracted from the coordinates before they
            are converted to grid indices
        dim: Dimension index for z (default: 1 for KITTI coords)
    """

    def __init__(self,
                 channels: int,
                 shape: Union[List[int], Tuple[int, int, int], torch.Tensor],
                 offset: Tuple[int, int, int] = (0, 0, 0),
                 dim: int = 1) -> None:
        super().__init__()
        self.channels = channels
        self.dim = dim

        # The trailing 0 leaves the fourth (batch) coordinate column as is.
        offset_buf = torch.IntTensor([list(offset) + [0]])
        if isinstance(shape, torch.Tensor):
            shape_buf = shape.int()
        else:
            shape_buf = torch.IntTensor(shape)
        self.register_buffer('offset', offset_buf)
        self.register_buffer('shape', shape_buf)

        self.bev_dims = [axis for axis in range(3) if axis != self.dim]
        self.bev_shape = self.shape[self.bev_dims]

    def forward(self, input: SparseTensor) -> torch.Tensor:
        """Scatter the sparse features into a dense BEV map.

        Args:
            input: Tensor exposing ``coords`` (N x 4, batch index in the
                last column), ``feats`` (N x C) and a 3-element ``stride``.

        Returns:
            Dense tensor of shape
            ``(batch, shape[dim] * C, bev_shape[0], bev_shape[1])``.
        """
        coords, feats = input.coords, input.feats
        # Axis order used below: the two BEV axes first, height last.
        axis_order = self.bev_dims + [self.dim]

        strides = torch.tensor(input.stride).unsqueeze(dim=0)
        strides = strides.to(coords.device)[:, axis_order].t()
        grid = self.shape[axis_order]
        n_cols, n_height = int(grid[1]), int(grid[2])

        # Rows become [batch, bev axis 0, bev axis 1, height].
        voxels = (coords - self.offset).t()[[3] + axis_order].long()
        voxels[1:] = torch.div(voxels[1:], strides).trunc().long()
        # Only the height row is clamped into the valid range.
        voxels[-1] = torch.clamp(voxels[-1], 0, n_height - 1)

        # Row-major linear index over (batch, bev 0, bev 1, height).
        cells_per_batch = int(grid.prod())
        flat = voxels[0] * cells_per_batch
        flat = flat + voxels[1] * (n_cols * n_height)
        flat = flat + voxels[2] * n_height
        flat = flat + voxels[3]

        batch_size = voxels[0].max().item() + 1
        dense_size = torch.Size(
            [batch_size * int(self.shape.prod()),
             feats.size(-1)])
        dense = torch.sparse_coo_tensor(flat.unsqueeze(dim=0), feats,
                                        dense_size).to_dense()

        # Leaving the last axis as -1 merges height and feature channels.
        dense = dense.view(batch_size, *self.bev_shape.tolist(), -1)
        return dense.permute(0, 3, 1, 2).contiguous()