PyTorch中topk函数的用法详解

PyTorch中topk函数的用法详解

今天边肖就给大家详细讲解一下PyTorch中topk函数的用法,有很好的参考价值。希望对你有帮助。来和边肖一起看看吧。

从名字就能看出来,这个函数是用来求张量中一个dim的顶k或底k的值以及对应的索引的。

用法

torch.topk(input,k,dim=None,maximum=True,sorted=True,out=None) - (Tensor,LongTensor)

输入:一个张量数据

k:表示获得第k个数据及其索引。

Dim:指定要排序的维,默认为最后一个维。

最大:如果为真,则按从大到小的顺序排序;如果为False,则从最小到最大排序。

排序:返回的结果按顺序返回。

Out:默认,不要

topk最常见的场合是找到一个样本,该样本被网络认为是前K个最可能的类别。让我们以这个场景为例来说明函数的使用。

假设一,n是样本数,一般等于批量,d是类别数。我们想知道每个样本最有可能属于哪个类别,但是我们实际上可以用torch.max得到如果你想用topk,那么K应该设置为1。

进口火炬

pred=torch.randn((4,5))

打印(预测)

values,indexes=pred . topk(1,dim=1,maximum=True,sorted=True)

打印(索引)

#使用max获得的结果,将keepdim设置为True以避免降维。因为topk函数返回的索引没有降维,所以形状与输入一致。

_,indexes _ max=pred . max(dim=1,keepdim=True)

print(indexes _ max==indexes)

#预测

张量([[-0.1480,-0.9819,-0.3364,0.7912,-0.3263),

[-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],

[-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],

[-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])

# indices,shape shape是[4,1],

张量([[3],# [0,0])表示第一个样本最有可能属于第一类。

[2],# [1,0]表示第二个样本最有可能属于第二类。

[4],

[2]])

# indices _ max等于索引

张量([[真],

[真],

[真],

[真]])

现在在尝试一下k=2

进口火炬

pred=torch.randn((4,5))

打印(预测)

values,indexes=pred . topk(2,dim=1,maximum=True,sorted=True) # k=2

打印(索引)

#预测

张量([[-0.2203,-0.7538,1.8789,0.4451,-0.2526),

[-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],

[ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],

[ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])

#索引

张量([[2,3],

[2, 1],

[2, 1],

[4, 1]])

可以发现,指数的形状变成[4,k],k=2。

其中索引[0]=[2,3]。意味着第一个样本的前两个最大概率分别对应类别3和类别4。

您可以打印自己的值。可以发现,值的形状与指数的形状相同。Indices描述了相应值在pred值中的位置。

以上对PyTorch中topk函数用法的详细说明,就是边肖分享的全部内容。希望给大家一个参考,多多支持我们。

PyTorch中topk函数的用法详解