torch.solve

torch.solve(input, A, *, out=None) -> (Tensor, Tensor)

This function returns the solution to the system of linear equations represented by AX = B and the LU factorization of A, in that order, as a namedtuple (solution, LU).

LU contains the L and U factors of the LU factorization of A.

torch.solve(B, A) accepts 2D inputs B, A or inputs that are batches of 2D matrices. If the inputs are batches, the outputs solution and LU are batched as well.
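The batch shape rule can be sketched as follows. This is a minimal illustration, not part of the original example set: it assumes torch.linalg.solve(A, B) as a stand-in for torch.solve(B, A) (note the swapped argument order, and that it returns only the solution, not LU), and it builds A as a diagonally dominant matrix purely so that every system in the batch is guaranteed to be solvable.

```python
import torch

torch.manual_seed(0)

# Batch shape (2, 3): six independent 4x4 systems, each with 6 right-hand sides.
# Entries drawn from [0, 1) plus 4 on the diagonal make every matrix strictly
# diagonally dominant, hence invertible (an illustrative choice).
A = torch.rand(2, 3, 4, 4) + 4.0 * torch.eye(4)
B = torch.randn(2, 3, 4, 6)

# Assumption: torch.linalg.solve(A, B) stands in for torch.solve(B, A).
X = torch.linalg.solve(A, B)

print(X.shape)  # torch.Size([2, 3, 4, 6])

# Largest absolute entry of A X - B across the whole batch; should be tiny.
residual = (A.matmul(X) - B).abs().max()
print(residual.item())
```

The solution keeps the batch dimensions of the inputs and has the same trailing shape (m, k) as B.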

Supports real-valued and complex-valued inputs.

Note

Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.
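To see what the stride expressions in this note evaluate to, the sketch below computes them for small tensors without calling the solver. The shapes (m = 3, k = 2) are arbitrary choices for illustration.

```python
import torch

B = torch.randn(3, 2)  # (m, k) = (3, 2), row-major layout
A = torch.randn(3, 3)  # (m, m) = (3, 3), row-major layout

# Strides of the freshly created (contiguous) tensors.
print(B.stride())  # (2, 1)
print(A.stride())  # (3, 1)

# The stride expressions quoted in the note.
b_strides = B.contiguous().transpose(-1, -2).stride()
a_strides = A.contiguous().transpose(-1, -2).stride()
print(b_strides)  # (1, 2)
print(a_strides)  # (1, 3)

# A tensor with such strides is a transposed view and is not contiguous;
# calling .contiguous() on it produces a row-major copy.
view = B.contiguous().transpose(-1, -2)
print(view.is_contiguous())               # False
print(view.contiguous().is_contiguous())  # True
```

If downstream code needs row-major memory, calling .contiguous() on the returned matrices is one way to obtain it, at the cost of a copy.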

Parameters
  • input (Tensor) – input matrix B of size (*, m, k), where * is zero or more batch dimensions.
  • A (Tensor) – input square matrix of size (*, m, m), where * is zero or more batch dimensions.
Keyword Arguments
  • out ((Tensor, Tensor), optional) – optional output tuple.

Example:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
       7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
   3.6386)

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.solve.html