Source code for torchsparse.nn.modules.norm

import torch
from torch import nn

from torchsparse import SparseTensor
from torchsparse.nn.utils import fapply

__all__ = ['BatchNorm', 'GroupNorm']


[docs]class BatchNorm(nn.BatchNorm1d): """ Batch normalization layer. """
[docs] def forward(self, input: SparseTensor) -> SparseTensor: return fapply(input, super().forward)
[docs]class GroupNorm(nn.GroupNorm): """ Group normalization layer. """
[docs] def forward(self, input: SparseTensor) -> SparseTensor: coords, feats, stride = input.coords, input.feats, input.stride batch_size = torch.max(coords[:, -1]).item() + 1 num_channels = feats.shape[1] # PyTorch's GroupNorm function expects the input to be in (N, C, *) # format where N is batch size, and C is number of channels. "feats" # is not in that format. So, we extract the feats corresponding to # each sample, bring it to the format expected by PyTorch's GroupNorm # function, and invoke it. nfeats = torch.zeros_like(feats) for k in range(batch_size): indices = coords[:, -1] == k bfeats = feats[indices] bfeats = bfeats.transpose(0, 1).reshape(1, num_channels, -1) bfeats = super().forward(bfeats) bfeats = bfeats.reshape(num_channels, -1).transpose(0, 1) nfeats[indices] = bfeats output = SparseTensor(coords=coords, feats=nfeats, stride=stride) output.cmaps = input.cmaps output.kmaps = input.kmaps return output