import torch
def visualize_weight_sharing():
    """
    Visualize how a single set of weights is shared across every sample in a batch.
    """
    batch_size = 3
    input_features = 4
    output_features = 2

    # Shared weight and bias: one copy, applied to every sample
    W = torch.randn(output_features, input_features)  # [2, 4]
    b = torch.randn(output_features)                  # [2]

    # Different input samples in the batch
    inputs = torch.randn(batch_size, input_features)  # [3, 4]

    print("=== Shared-weight computation ===")
    print(f"Shared weight W shape: {W.shape}")
    print(f"Shared bias b shape: {b.shape}")
    print(f"Input batch shape: {inputs.shape}")
    print(f"Weight W:\n{W}")
    print(f"Bias b: {b}")
    print(f"Input batch:\n{inputs}")

    # Compute each sample's output by hand, one at a time
    print("\n--- Per-sample computation ---")
    for i in range(batch_size):
        single_input = inputs[i]                               # [4]
        single_output = torch.matmul(single_input, W.t()) + b  # [2]
        print(f"Sample {i}: {single_input} × W^T + b = {single_output}")

    # Batched computation (what is actually used in practice)
    batch_output = torch.matmul(inputs, W.t()) + b  # [3, 2]
    print(f"\nBatched result:\n{batch_output}")
visualize_weight_sharing()
# === Shared-weight computation ===
# Shared weight W shape: torch.Size([2, 4])
# Shared bias b shape: torch.Size([2])
# Input batch shape: torch.Size([3, 4])
# Weight W:
# tensor([[-0.0892, -1.9276, -2.0218,  0.0952],
#         [ 0.9642, -0.8301, -1.4846,  0.0314]])
# Bias b: tensor([-0.7340,  1.6211])
# Input batch:
# tensor([[ 1.0523,  0.3266, -0.3225,  1.3272],
#         [ 0.7284, -1.2795, -0.8708,  0.6347],
#         [ 0.0878, -1.0075,  0.0264, -0.9797]])
#
# --- Per-sample computation ---
# Sample 0: tensor([ 1.0523,  0.3266, -0.3225,  1.3272]) × W^T + b = tensor([-0.6791,  2.8850])
# Sample 1: tensor([ 0.7284, -1.2795, -0.8708,  0.6347]) × W^T + b = tensor([3.4882, 4.6982])
# Sample 2: tensor([ 0.0878, -1.0075,  0.0264, -0.9797]) × W^T + b = tensor([1.0535, 2.4722])
#
# Batched result:
# tensor([[-0.6791,  2.8850],
#         [ 3.4882,  4.6982],
#         [ 1.0535,  2.4722]])
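
# Sanity check (a minimal sketch, not part of the walkthrough above):
# torch.nn.functional.linear(x, W, b) computes x @ W.T + b, so for the same
# shared W and b it should reproduce the manual batched result exactly.
# Fresh tensors are created here since the function's locals are out of scope.
import torch.nn.functional as F

W = torch.randn(2, 4)
b = torch.randn(2)
inputs = torch.randn(3, 4)
manual = torch.matmul(inputs, W.t()) + b  # the hand-rolled version from above
builtin = F.linear(inputs, W, b)          # PyTorch's built-in affine transform
print(torch.allclose(manual, builtin))    # True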
import inspect
# Print torch.matmul's docstring; the excerpt below shows only its Example section.
print(inspect.getdoc(torch.matmul))
# >>> # vector x vector
# >>> tensor1 = torch.randn(3)
# >>> tensor2 = torch.randn(3)
# >>> torch.matmul(tensor1, tensor2).size()
# torch.Size([])
# >>> # matrix x vector
# >>> tensor1 = torch.randn(3, 4)
# >>> tensor2 = torch.randn(4)
# >>> torch.matmul(tensor1, tensor2).size()
# torch.Size([3])
# >>> # batched matrix x broadcasted vector
# >>> tensor1 = torch.randn(10, 3, 4)
# >>> tensor2 = torch.randn(4)
# >>> torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3])
# >>> # batched matrix x batched matrix
# >>> tensor1 = torch.randn(10, 3, 4)
# >>> tensor2 = torch.randn(10, 4, 5)
# >>> torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3, 5])
# >>> # batched matrix x broadcasted matrix
# >>> tensor1 = torch.randn(10, 3, 4)
# >>> tensor2 = torch.randn(4, 5)
# >>> torch.matmul(tensor1, tensor2).size()
# torch.Size([10, 3, 5])
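
# A quick sketch verifying the "batched matrix x broadcasted matrix" case above:
# the single [4, 5] matrix is shared across all 10 batch entries, so matmul
# should agree with an explicit per-entry loop (the same weight-sharing idea).
t1 = torch.randn(10, 3, 4)
t2 = torch.randn(4, 5)
batched = torch.matmul(t1, t2)              # broadcasts t2 across the batch -> [10, 3, 5]
looped = torch.stack([m @ t2 for m in t1])  # one batch entry at a time
print(batched.shape, torch.allclose(batched, looped))
# torch.Size([10, 3, 5]) True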