PyTorch的矩阵乘法

下载torch4matmul.ipynb
Python示例代码

        # 主要区别总结:
        # @ 和 torch.matmul(): 支持广播,适用于各种维度
        # torch.mm(): 仅支持2D矩阵,不广播
        # *: 逐元素乘法,要求形状相同
        # torch.dot(): 仅支持1D向量的点积
        # 推荐使用 @ 运算符,因为它最直观且功能全面
        
        import torch

    #矩阵乘法@
    x = torch.tensor([1,2,3,4,5,6]).reshape(2,3)
    y = torch.tensor([1,2,3,4,5,6]).reshape(3,2)
    out=x @ y
    print(x)
    print(y)
    print(out)

    # tensor([[1, 2, 3],
    #         [4, 5, 6]])
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6]])
    # tensor([[22, 28],
    #         [49, 64]])
    
    #矩阵乘法torch.matmul()
    out=torch.matmul(x, y)
    print(out)
    # tensor([[22, 28],
    #         [49, 64]])
    
        

    #  Element-wise元素级别的乘法
    x = torch.tensor([1,2,3,4]).reshape(2, 2)
    y = torch.tensor([2,2,2,2]).reshape(2, 2)
    element_wise= x*y
    print(x)
    print(y)
    print(element_wise)

    # tensor([[1, 2],
    #         [3, 4]])
    # tensor([[2, 2],
    #         [2, 2]])
    # tensor([[2, 4],
    #         [6, 8]])
    
    

        #后两维做矩阵乘法
        batch_x = torch.randn(3, 2, 3)
        batch_y = torch.randn(3, 3, 2)

        print(batch_x)
        print(batch_y)

        out = torch.matmul(batch_x, batch_y)
        print(out)
        
        
        #验证上面的代码(batchsize相当于,batch_x的第一个维度相当于for循环(批量数据处理))
        x = torch.arange(1, 19).reshape(3, 2, 3)
        y = torch.arange(1, 19).reshape(3, 3, 2)

        print(x)
        print(y)
        out = torch.matmul(x, y)
        print(out)

        # tensor([[[ 1,  2,  3],
        #          [ 4,  5,  6]],

        #         [[ 7,  8,  9],
        #          [10, 11, 12]],

        #         [[13, 14, 15],
        #          [16, 17, 18]]])
        # tensor([[[ 1,  2],
        #          [ 3,  4],
        #          [ 5,  6]],

        #         [[ 7,  8],
        #          [ 9, 10],
        #          [11, 12]],

        #         [[13, 14],
        #          [15, 16],
        #          [17, 18]]])
        # tensor([[[ 22,  28],
        #          [ 49,  64]],

        #         [[220, 244],
        #          [301, 334]],

        #         [[634, 676],
        #          [769, 820]]])
    
    

        x = torch.arange(1, 7).reshape(2, 3)
        y = torch.arange(1, 7).reshape(3, 2)

        print(x)
        print(y)
        out = torch.matmul(x, y)
        print(out)

        # tensor([[1, 2, 3],
        #         [4, 5, 6]])
        # tensor([[1, 2],
        #         [3, 4],
        #         [5, 6]])
        # tensor([[22, 28],
        #         [49, 64]])