from typing import Any, List
import numpy as np
import torch
from torchsparse import SparseTensor
# Public names exported by a star import of this module.
__all__ = ['sparse_collate', 'sparse_collate_fn']
def sparse_collate(inputs: List[SparseTensor]) -> SparseTensor:
    """Assemble a batch of sparse tensors and add the batch dimension to coords.

    Each input's coordinates get an extra trailing column (``torch.int``)
    holding the position of that input within ``inputs``. Coordinates and
    features are then concatenated along dimension 0. NumPy coordinate or
    feature arrays are converted with ``torch.tensor``. The input sparse
    tensors themselves are left unmodified.

    Args:
        inputs (List[SparseTensor]): A non-empty list of sparse tensors that
            all share the same ``stride``.

    Returns:
        SparseTensor: A batch of collated sparse tensors with the common
        stride of the inputs.

    Raises:
        ValueError: If ``inputs`` is empty.
    """
    if not inputs:
        raise ValueError(
            'sparse_collate() expects a non-empty list of SparseTensors')

    stride = inputs[0].stride
    coords, feats = [], []
    for k, x in enumerate(inputs):
        # Convert NumPy data into locals rather than assigning back onto
        # ``x``: collating must not mutate the caller's (possibly cached)
        # samples.
        x_coords = x.coords
        if isinstance(x_coords, np.ndarray):
            x_coords = torch.tensor(x_coords)
        x_feats = x.feats
        if isinstance(x_feats, np.ndarray):
            x_feats = torch.tensor(x_feats)

        assert isinstance(x_coords, torch.Tensor), type(x_coords)
        assert isinstance(x_feats, torch.Tensor), type(x_feats)
        assert x.stride == stride, (x.stride, stride)

        # One batch-index column per point, filled with this sample's index.
        batch = torch.full((x_coords.shape[0], 1),
                           k,
                           device=x_coords.device,
                           dtype=torch.int)
        coords.append(torch.cat((x_coords, batch), dim=1))
        feats.append(x_feats)

    coords = torch.cat(coords, dim=0)
    feats = torch.cat(feats, dim=0)
    return SparseTensor(coords=coords, feats=feats, stride=stride)
def sparse_collate_fn(inputs: List[Any]) -> Any:
    """Collate a list of samples, batching any sparse tensors they contain.

    The type of the first sample decides what happens. If it is not a dict,
    ``inputs`` is returned unchanged. Otherwise a new dict is built with the
    first sample's keys. For each key, the values from all samples are
    combined according to the type of the first sample's value:

    * ``dict``: collated recursively with this function;
    * ``np.ndarray``: each converted with ``torch.tensor``, then stacked
      along a new leading dimension;
    * ``torch.Tensor``: stacked along a new leading dimension;
    * ``SparseTensor``: merged with ``sparse_collate``;
    * anything else: gathered into a plain list.

    Args:
        inputs (List[Any]): A list of samples. The first element is
            inspected, so the list must be non-empty.

    Returns:
        Any: inputs with the sparse tensors collated.
    """
    first = inputs[0]
    if not isinstance(first, dict):
        return inputs

    collated = {}
    for key, template in first.items():
        # Values for this key across every sample, in sample order.
        values = [sample[key] for sample in inputs]
        if isinstance(template, dict):
            collated[key] = sparse_collate_fn(values)
        elif isinstance(template, np.ndarray):
            collated[key] = torch.stack([torch.tensor(v) for v in values],
                                        dim=0)
        elif isinstance(template, torch.Tensor):
            collated[key] = torch.stack(values, dim=0)
        elif isinstance(template, SparseTensor):
            collated[key] = sparse_collate(values)
        else:
            collated[key] = values
    return collated