Source code for torchsparse.nn.functional.query

import torch

import torchsparse.backend

# Public API of this module: the only name exported by a star-import.
__all__ = ['sphashquery']


def sphashquery(queries: torch.Tensor, references: torch.Tensor) -> torch.Tensor:
    """Look up every entry of ``queries`` in ``references`` via the hash backend.

    ``queries`` is flattened before the lookup and the result is reshaped
    back to the original shape of ``queries``. Each reference is paired with
    its position ``0 .. len(references) - 1``, and the backend result is
    shifted down by one before being returned.

    Args:
        queries: Tensor of values to look up; may have any shape.
        references: Tensor of values to search in. Its length along the
            first dimension determines the range of positions.

    Returns:
        A tensor with the same shape as ``queries`` holding the backend
        result minus one.

    Note:
        Presumably the backend reports 1-based positions and uses 0 for
        "not found", which would make ``-1`` the marker for a missing query
        -- TODO confirm against the backend implementation.
    """
    flat_queries = queries.contiguous()
    references = references.contiguous()

    # Remember the caller's shape so the flat result can be folded back.
    original_shape = flat_queries.size()
    flat_queries = flat_queries.view(-1)
    target_device = flat_queries.device

    positions = torch.arange(
        len(references),
        device=target_device,
        dtype=torch.long,
    )

    device_kind = target_device.type
    if device_kind not in ('cuda', 'cpu'):
        # No dedicated kernel for this device: run the CPU kernel on host
        # copies and move the result back to the original device.
        raw = torchsparse.backend.hash_query_cpu(
            flat_queries.cpu(),
            references.cpu(),
            positions.cpu(),
        ).to(target_device)
    elif device_kind == 'cuda':
        raw = torchsparse.backend.hash_query_cuda(
            flat_queries, references, positions)
    else:
        raw = torchsparse.backend.hash_query_cpu(
            flat_queries, references, positions)

    return (raw - 1).view(*original_shape)