PyTorch Matrix Multiplication: Shared Weights Across a Batch

Download weight4batch.ipynb
Python example code

        import torch

        def visualize_weight_sharing():
            """
            Visualize how one weight matrix is shared by every sample in a batch.
            """
            batch_size = 3
            input_features = 4
            output_features = 2
            
            # Shared weight and bias: a single copy used by all samples
            W = torch.randn(output_features, input_features)  # [2, 4]
            b = torch.randn(output_features)                  # [2]
            
            # A batch of distinct input samples
            inputs = torch.randn(batch_size, input_features)  # [3, 4]
            
            print("=== Shared-weight computation ===")
            print(f"Shared weight W shape: {W.shape}")
            print(f"Shared bias b shape: {b.shape}")
            print(f"Input batch shape: {inputs.shape}")
            print(f"Weight W:\n{W}")
            print(f"Bias b: {b}")
            print(f"Input batch:\n{inputs}")
            
            # Compute each sample's output one at a time
            print("\n--- Per-sample computation ---")
            for i in range(batch_size):
                single_input = inputs[i]  # [4]
                single_output = torch.matmul(single_input, W.t()) + b  # [2]
                print(f"Sample {i}: {single_input} × W^T + b = {single_output}")
            
            # Batched computation (what is actually used in practice);
            # b with shape [2] is broadcast across the 3 rows of the result
            batch_output = torch.matmul(inputs, W.t()) + b  # [3, 2]
            print(f"\nBatched result:\n{batch_output}")

            
        visualize_weight_sharing()

            # === Shared-weight computation ===
            # Shared weight W shape: torch.Size([2, 4])
            # Shared bias b shape: torch.Size([2])
            # Input batch shape: torch.Size([3, 4])
            # Weight W:
            # tensor([[-0.0892, -1.9276, -2.0218,  0.0952],
            #         [ 0.9642, -0.8301, -1.4846,  0.0314]])
            # Bias b: tensor([-0.7340,  1.6211])
            # Input batch:
            # tensor([[ 1.0523,  0.3266, -0.3225,  1.3272],
            #         [ 0.7284, -1.2795, -0.8708,  0.6347],
            #         [ 0.0878, -1.0075,  0.0264, -0.9797]])

            # --- Per-sample computation ---
            # Sample 0: tensor([ 1.0523,  0.3266, -0.3225,  1.3272]) × W^T + b = tensor([-0.6791,  2.8850])
            # Sample 1: tensor([ 0.7284, -1.2795, -0.8708,  0.6347]) × W^T + b = tensor([3.4882, 4.6982])
            # Sample 2: tensor([ 0.0878, -1.0075,  0.0264, -0.9797]) × W^T + b = tensor([1.0535, 2.4722])

            # Batched result:
            # tensor([[-0.6791,  2.8850],
            #         [ 3.4882,  4.6982],
            #         [ 1.0535,  2.4722]])
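
This is exactly what torch.nn.Linear computes: y = x × W^T + b, with a single W and b shared by every row of the batch. A minimal sketch checking the equivalence (the seed and variable names here are illustrative, not from the original):

        import torch
        import torch.nn as nn

        torch.manual_seed(0)
        linear = nn.Linear(4, 2)   # holds a shared W [2, 4] and b [2]
        x = torch.randn(3, 4)      # batch of 3 samples

        # Redo the layer's computation by hand and compare
        manual = torch.matmul(x, linear.weight.t()) + linear.bias
        print(torch.allclose(linear(x), manual))  # True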
    
    

The broadcasting rules that make the batched call work are documented in torch.matmul's docstring; its Examples section is excerpted below:

        import inspect

        print(inspect.getdoc(torch.matmul))


        #     >>> # vector x vector
        #     >>> tensor1 = torch.randn(3)
        #     >>> tensor2 = torch.randn(3)
        #     >>> torch.matmul(tensor1, tensor2).size()
        #     torch.Size([])
        #     >>> # matrix x vector
        #     >>> tensor1 = torch.randn(3, 4)
        #     >>> tensor2 = torch.randn(4)
        #     >>> torch.matmul(tensor1, tensor2).size()
        #     torch.Size([3])
        #     >>> # batched matrix x broadcasted vector
        #     >>> tensor1 = torch.randn(10, 3, 4)
        #     >>> tensor2 = torch.randn(4)
        #     >>> torch.matmul(tensor1, tensor2).size()
        #     torch.Size([10, 3])
        #     >>> # batched matrix x batched matrix
        #     >>> tensor1 = torch.randn(10, 3, 4)
        #     >>> tensor2 = torch.randn(10, 4, 5)
        #     >>> torch.matmul(tensor1, tensor2).size()
        #     torch.Size([10, 3, 5])
        #     >>> # batched matrix x broadcasted matrix
        #     >>> tensor1 = torch.randn(10, 3, 4)
        #     >>> tensor2 = torch.randn(4, 5)
        #     >>> torch.matmul(tensor1, tensor2).size()
        #     torch.Size([10, 3, 5])
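
As a quick check of the last broadcasting case above (a minimal sketch; the tensor names are illustrative, not from the original), a batched matrix times a broadcast matrix matches multiplying each batch element separately:

        import torch

        a = torch.randn(10, 3, 4)   # batch of 10 [3, 4] matrices
        m = torch.randn(4, 5)       # one [4, 5] matrix, broadcast over the batch

        out = torch.matmul(a, m)    # [10, 3, 5]
        looped = torch.stack([torch.matmul(a[i], m) for i in range(10)])
        print(out.shape, torch.allclose(out, looped))  # torch.Size([10, 3, 5]) True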