PyTorch的矩阵乘法与线性变换
#torch.nn下的Linear是线性变换,包括矩阵乘法和偏置量
import torch
import torch.nn as nn
batch_size = 4
input_features = 4
output_features =3
x=torch.randn(batch_size, input_features)
linear_layer = nn.Linear(input_features, output_features) #自动生成随机的权重矩阵weight,和偏置bias
out_linear = linear_layer(x)
#使用matmul
weight = linear_layer.weight
bias = linear_layer.bias
out_matmul = x @ weight.T +bias
print(out_linear)
print(out_matmul)
print("Is equal=", torch.allclose(out_linear, out_matmul))
# tensor([[ 0.1254, 0.6651, -0.3266],
# [ 0.0621, 0.8397, 0.6082],
# [-0.5628, 0.0811, 0.4234],
# [ 0.6737, 0.0293, -0.2010]], grad_fn=)
# tensor([[ 0.1254, 0.6651, -0.3266],
# [ 0.0621, 0.8397, 0.6082],
# [-0.5628, 0.0811, 0.4234],
# [ 0.6737, 0.0293, -0.2010]], grad_fn=)
# Is equal= True