# Source code for torchsparse.nn.functional.pooling

import torch

from torchsparse import SparseTensor

__all__ = ['global_avg_pool', 'global_max_pool']


def global_avg_pool(inputs: SparseTensor) -> torch.Tensor:
    """Average the features of every sample in a batched sparse tensor.

    The last column of ``inputs.coords`` is used as the per-point batch
    index; rows of ``inputs.feats`` that share an index are averaged
    together along dim 0.

    Args:
        inputs: Sparse tensor providing ``coords`` and ``feats``.

    Returns:
        One pooled row per batch index ``0 .. max_index``, stacked along
        dim 0. When ``inputs`` holds no points at all, an empty tensor with
        zero rows and the trailing dimensions of ``inputs.feats`` is
        returned.

    Note:
        A batch index inside the range that owns no points is reduced over
        an empty selection; for floating-point features ``torch.mean``
        yields a NaN row in that case.
    """
    feats = inputs.feats
    # Slice the batch-index column once instead of once per sample.
    batch_ids = inputs.coords[:, -1]

    # torch.max rejects empty tensors, so handle "no points" explicitly.
    if batch_ids.numel() == 0:
        return feats.new_empty((0, *feats.shape[1:]))

    # int() keeps range() happy even if coords are stored as floats.
    batch_size = int(batch_ids.max().item()) + 1
    pooled = [
        torch.mean(feats[batch_ids == k], dim=0) for k in range(batch_size)
    ]
    return torch.stack(pooled, dim=0)
def global_max_pool(inputs: SparseTensor) -> torch.Tensor:
    """Take the feature-wise maximum of every sample in a sparse tensor.

    The last column of ``inputs.coords`` is used as the per-point batch
    index; rows of ``inputs.feats`` that share an index are reduced with
    ``torch.max`` along dim 0 (only the values are kept, not the argmax
    indices).

    Args:
        inputs: Sparse tensor providing ``coords`` and ``feats``.

    Returns:
        One pooled row per batch index ``0 .. max_index``, stacked along
        dim 0. When ``inputs`` holds no points at all, an empty tensor with
        zero rows and the trailing dimensions of ``inputs.feats`` is
        returned.

    Note:
        A batch index inside the range that owns no points is reduced over
        an empty selection, which ``torch.max`` is expected to reject with
        an error -- callers should ensure every sample has points.
    """
    feats = inputs.feats
    # Slice the batch-index column once instead of once per sample.
    batch_ids = inputs.coords[:, -1]

    # torch.max rejects empty tensors, so handle "no points" explicitly.
    if batch_ids.numel() == 0:
        return feats.new_empty((0, *feats.shape[1:]))

    # int() keeps range() happy even if coords are stored as floats.
    batch_size = int(batch_ids.max().item()) + 1
    pooled = [
        # torch.max with ``dim`` returns (values, indices); keep the values.
        torch.max(feats[batch_ids == k], dim=0)[0] for k in range(batch_size)
    ]
    return torch.stack(pooled, dim=0)