PyTorch的矩阵乘法
# 主要区别总结:
# @ 和 torch.matmul(): 支持广播,适用于各种维度
# torch.mm(): 仅支持2D矩阵,不广播
# *: 逐元素乘法,要求形状相同
# torch.dot(): 仅支持1D向量的点积
# 推荐使用 @ 运算符,因为它最直观且功能全面
import torch
#矩阵乘法@
x = torch.tensor([1,2,3,4,5,6]).reshape(2,3)
y = torch.tensor([1,2,3,4,5,6]).reshape(3,2)
out=x @ y
print(x)
print(y)
print(out)
# tensor([[1, 2, 3],
# [4, 5, 6]])
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
# tensor([[22, 28],
# [49, 64]])
#矩阵乘法torch.matmul()
out=torch.matmul(x, y)
print(out)
# tensor([[22, 28],
# [49, 64]])
# Element-wise元素级别的乘法
x = torch.tensor([1,2,3,4]).reshape(2, 2)
y = torch.tensor([2,2,2,2]).reshape(2, 2)
element_wise= x*y
print(x)
print(y)
print(element_wise)
# tensor([[1, 2],
# [3, 4]])
# tensor([[2, 2],
# [2, 2]])
# tensor([[2, 4],
# [6, 8]])
#后两维做矩阵乘法
batch_x = torch.randn(3, 2, 3)
batch_y = torch.randn(3, 3, 2)
print(batch_x)
print(batch_y)
out = torch.matmul(batch_x, batch_y)
print(out)
#验证上面的代码(batchsize相当于,batch_x的第一个维度相当于for循环(批量数据处理))
x = torch.arange(1, 19).reshape(3, 2, 3)
y = torch.arange(1, 19).reshape(3, 3, 2)
print(x)
print(y)
out = torch.matmul(x, y)
print(out)
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6]],
# [[ 7, 8, 9],
# [10, 11, 12]],
# [[13, 14, 15],
# [16, 17, 18]]])
# tensor([[[ 1, 2],
# [ 3, 4],
# [ 5, 6]],
# [[ 7, 8],
# [ 9, 10],
# [11, 12]],
# [[13, 14],
# [15, 16],
# [17, 18]]])
# tensor([[[ 22, 28],
# [ 49, 64]],
# [[220, 244],
# [301, 334]],
# [[634, 676],
# [769, 820]]])
x = torch.arange(1, 7).reshape(2, 3)
y = torch.arange(1, 7).reshape(3, 2)
print(x)
print(y)
out = torch.matmul(x, y)
print(out)
# tensor([[1, 2, 3],
# [4, 5, 6]])
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
# tensor([[22, 28],
# [49, 64]])