pytorch矩阵点乘,pytorch tensor相乘
在本文中,目录点乘以` torch.mul(a,b)`二维矩阵乘以` torch.mm(a,b)`三维矩阵乘以` torch.bmm(`a,b)`高维矩阵乘以` torch.matmul(a,b)。
点乘火炬. mul(a,b)
点乘是相应位置元素的乘法。
点乘是广播式的,可以用torch.mul(a,b)实现,也可以直接用*实现。
python中的广播机制
广播可以这样理解:如果你有一个(m,n)的矩阵,让它对一个(1,n)的矩阵进行加减乘除,它会被复制m次成为(m,n)的矩阵,然后再对它进行元素间的加减乘除。类似地,接地对矩阵(m,1)成立。
资料来源:https://www.jianshu.com/p/fadd169cd396
当维度A和B满足广播机制时,会自动填充到同维度相乘中。
比如A的维数是(2,3),B的维数是(1,3);
或者:A的维数是(2,3),B的维数是(2,1)。当A和B的维数不满足广播机制时,要求A和B的维数必须相等。
如果A的维数为(1,2),B的维数为(2,3),则会报错:张量A (2)的大小必须与张量B (3)在非奇异维数1处的大小相匹配。
报错是指B中维数为3的位置必须与A中维数为2的位置匹配,因为A中有一个维数为1,必须为(1,2)和(2,2)才能满足广播机制,否则需要满足维数必须等于(2,3)和(2,3)二维矩阵乘以torch.mm (a,b) torch。
二维矩阵乘法要求参数A和B的维数满足乘法要求。
该函数一般只用于计算两个二维矩阵的矩阵乘法,不支持广播运算。
提供三维矩阵乘以torch.bmm(a,b)是因为神经网络训练一般采用mini-batch,经常输入带batch的三维矩阵。
torch.bmm(bmat1,bmat2,out=无)
该函数的两个输入必须是三维矩阵,第一个维度相同(代表批量维度),后两个维度满足矩阵乘法的要求。不支持广播操作
高维矩阵的最后两个维度乘以torch。手电筒。matmul (input,other,out=none)满足矩阵乘法的要求,前面的维度被认为是batch_size,使用广播机制。
主要参考资料:
《pytorch的火炬》中的几个乘法
pytorch中矩阵乘法的总结