函数原型:
torch.nonzero(input, out=None) → LongTensor
参数:
input (Tensor) – 源张量
out (LongTensor, optional) – 包含索引值的结果张量
代码示例
返回一个包含输入input中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
x = torch.tensor([0, 0, 1, 5, 8])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[2],[3],[4]])
torch.Size([3, 1])
如果输入input有n维,则输出的索引张量output的形状为 z x n, 这里 z 是输入张量input中所有非零元素的个数。
x = torch.tensor([[0, 0, 1, 5],[1, 5, 0, 8],[2, 8, 9, 0]])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[0, 2],[0, 3],[1, 0],[1, 1],[1, 3],[2, 0],[2, 1],[2, 2]])
torch.Size([8, 2])
pytorch文档学习链接:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/