"""Global pooling layers for sparse tensors (``torchsparse.nn.modules.pooling``)."""

import torch
from torch import nn

from torchsparse import SparseTensor
from torchsparse.nn import functional as F

# Public API of this module: the two global pooling layers defined below.
__all__ = ['GlobalAvgPool', 'GlobalMaxPool']


class GlobalAvgPool(nn.Module):
    """Global average pooling layer.

    Thin ``nn.Module`` wrapper around
    ``torchsparse.nn.functional.global_avg_pool``. It can therefore be placed
    inside ``nn.Sequential`` and other module containers. The layer defines
    no ``__init__``, so it has no learnable parameters or buffers.
    """

    def forward(self, input: SparseTensor) -> torch.Tensor:
        """Apply global average pooling to ``input``.

        Args:
            input: The sparse tensor to pool. The parameter name shadows the
                builtin ``input`` but is kept for keyword-argument
                compatibility with existing callers.

        Returns:
            The dense ``torch.Tensor`` returned by ``F.global_avg_pool``.
            NOTE(review): presumably one pooled feature vector per batch
            sample; confirm against the functional implementation.
        """
        return F.global_avg_pool(input)
class GlobalMaxPool(nn.Module):
    """Global max pooling layer.

    Thin ``nn.Module`` wrapper around
    ``torchsparse.nn.functional.global_max_pool``. It can therefore be placed
    inside ``nn.Sequential`` and other module containers. The layer defines
    no ``__init__``, so it has no learnable parameters or buffers.
    """

    def forward(self, input: SparseTensor) -> torch.Tensor:
        """Apply global max pooling to ``input``.

        Args:
            input: The sparse tensor to pool. The parameter name shadows the
                builtin ``input`` but is kept for keyword-argument
                compatibility with existing callers.

        Returns:
            The dense ``torch.Tensor`` returned by ``F.global_max_pool``.
            NOTE(review): presumably one pooled feature vector per batch
            sample; confirm against the functional implementation.
        """
        return F.global_max_pool(input)