torch.mm

torch.mm(input, mat2, *, out=None) → Tensor

Performs a matrix multiplication of the matrices input and mat2.

If input is an (n × m) tensor and mat2 is an (m × p) tensor, out will be an (n × p) tensor.
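The shape rule can be checked with a small deterministic sketch. The names a, b and c are illustrative only and are not part of the API.

```python
import torch

# a is (n x m) = (2 x 3); b is (m x p) = (3 x 4).
a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = torch.ones(3, 4)

# The inner dimensions (3 and 3) must match; the result is (n x p) = (2 x 4).
c = torch.mm(a, b)
print(c.shape)
```

Because b is all ones, each row of c repeats the sum of the matching row of a, which makes the result easy to verify by hand.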

Note

This function does not broadcast. For broadcasting matrix products, see torch.matmul().
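As a rough illustration of the difference, the sketch below passes a 3-D batch to both functions. The variable names are hypothetical; the expectation is that torch.mm rejects the 3-D input with a RuntimeError, while torch.matmul broadcasts the 2-D matrix across the batch.

```python
import torch

batch = torch.randn(5, 2, 3)  # 3-D: a batch of five (2 x 3) matrices
mat = torch.randn(3, 4)       # a single (3 x 4) matrix

try:
    torch.mm(batch, mat)      # mm accepts only 2-D tensors
    mm_failed = False
except RuntimeError:
    mm_failed = True

# matmul broadcasts mat across the leading batch dimension.
out = torch.matmul(batch, mat)
print(mm_failed, out.shape)
```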

Supports strided and sparse 2-D tensors as inputs, and autograd with respect to strided inputs.
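A minimal sketch of a sparse-times-dense product, assuming a sparse COO tensor as the first argument and a strided tensor that tracks gradients as the second. The tensor names and values are made up for illustration.

```python
import torch

# Sparse COO (2 x 3) matrix with two nonzeros: (0, 2) -> 1.0 and (1, 0) -> 2.0.
indices = torch.tensor([[0, 1], [2, 0]])
values = torch.tensor([1.0, 2.0])
sp = torch.sparse_coo_tensor(indices, values, (2, 3))

# Strided (dense) input that participates in autograd.
dense = torch.randn(3, 4, requires_grad=True)

result = torch.mm(sp, dense)               # sparse @ dense
result.backward(torch.ones_like(result))   # gradient flows to the strided input

print(result.shape, dense.grad.shape)
```

With an all-ones upstream gradient, dense.grad should equal the transpose of the sparse matrix multiplied by a (2 x 4) matrix of ones.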

This operator supports TensorFloat32.
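The sketch below shows one way TensorFloat32 use might be controlled around a call to torch.mm. It assumes the torch.backends.cuda.matmul.allow_tf32 flag, which is not mentioned on this page, and the flag only has an effect for CUDA tensors on GPUs with TF32 hardware support. On other devices the code still runs, but the setting changes nothing.

```python
import torch

# Assumption: this flag governs TF32 use in CUDA matrix multiplications.
previous = torch.backends.cuda.matmul.allow_tf32

torch.backends.cuda.matmul.allow_tf32 = False  # request full FP32 precision

a = torch.randn(64, 64)
b = torch.randn(64, 64)
if torch.cuda.is_available():
    a, b = a.cuda(), b.cuda()

c = torch.mm(a, b)

# Restore the earlier setting so other code is unaffected.
torch.backends.cuda.matmul.allow_tf32 = previous
```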

Parameters
  • input (Tensor) – the first matrix to be matrix multiplied
  • mat2 (Tensor) – the second matrix to be matrix multiplied
Keyword Arguments
  • out (Tensor, optional) – the output tensor.
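A short sketch of the out keyword, assuming a preallocated buffer of the result shape (n × p). The name buf is illustrative.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

buf = torch.empty(2, 4)               # preallocated (n x p) result buffer
returned = torch.mm(a, b, out=buf)    # the product is written into buf

# The returned tensor shares storage with buf.
print(returned.data_ptr() == buf.data_ptr())
```

Reusing a buffer this way can avoid a fresh allocation on each call, for example inside a loop that repeatedly multiplies matrices of the same shape.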

Example:

>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.mm(mat1, mat2)
tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.mm.html