torch.gather scatter
torch.gather
import torch b = a.gather(dim, index) b = torch.gather(src, 1, torch.tensor([[0,0],[1,0]])) dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐列进行行填充。 * 当dim=1时,index[0][0]的元素是1,那么它想要查找a[0][1]中的元素; * 当dim=0时,index[0][0]的元素是1,那么它想查找的a[1][0]中的元素; index: 按照轴方向,在target张量中需要填充的位置
a = torch.arange(15).view(3, 5)
b = torch.zeros_like(a)
b[1][2] = 1
b[0][0] = 1
c = a.gather(0, b) # dim=0
d = a.gath
共有 0 条评论