# Source code for torchsparse.nn.functional.pooling
import torch
from torchsparse import SparseTensor
__all__ = ['global_avg_pool', 'global_max_pool']
def global_avg_pool(inputs: SparseTensor) -> torch.Tensor:
    """Average the features of every batch element of a sparse tensor.

    The last column of ``inputs.coords`` is used as the batch index of the
    corresponding row of ``inputs.feats``.

    Args:
        inputs: Sparse tensor exposing ``coords`` (indexed as a 2-D tensor
            whose last column holds the batch index) and ``feats``.

    Returns:
        A tensor with one row per batch index from ``0`` up to the largest
        index found in ``inputs.coords``; row ``k`` is the mean over dim 0
        of the feature rows whose batch index equals ``k``.
    """
    # Slice the batch-index column once instead of on every iteration.
    batch_index = inputs.coords[:, -1]
    # The batch size is inferred from the largest batch index present.
    batch_size = int(torch.max(batch_index).item()) + 1
    pooled = [
        torch.mean(inputs.feats[batch_index == k], dim=0)
        for k in range(batch_size)
    ]
    return torch.stack(pooled, dim=0)
def global_max_pool(inputs: SparseTensor) -> torch.Tensor:
    """Take the per-channel maximum of every batch element's features.

    The last column of ``inputs.coords`` is used as the batch index of the
    corresponding row of ``inputs.feats``.

    Args:
        inputs: Sparse tensor exposing ``coords`` (indexed as a 2-D tensor
            whose last column holds the batch index) and ``feats``.

    Returns:
        A tensor with one row per batch index from ``0`` up to the largest
        index found in ``inputs.coords``; row ``k`` is the maximum over
        dim 0 of the feature rows whose batch index equals ``k``.
    """
    # Slice the batch-index column once instead of on every iteration.
    batch_index = inputs.coords[:, -1]
    # The batch size is inferred from the largest batch index present.
    batch_size = int(torch.max(batch_index).item()) + 1
    pooled = [
        # torch.max with ``dim`` returns (values, indices); keep the values.
        torch.max(inputs.feats[batch_index == k], dim=0)[0]
        for k in range(batch_size)
    ]
    return torch.stack(pooled, dim=0)