torch.lu_solve

torch.lu_solve(b, LU_data, LU_pivots, *, out=None) → Tensor

Returns the LU solve of the linear system Ax = b using the partially pivoted LU factorization of A from torch.lu().

This function supports float, double, cfloat and cdouble dtypes for input.

Parameters
  • b (Tensor) – the RHS tensor of size (*, m, k), where * is zero or more batch dimensions.
  • LU_data (Tensor) – the pivoted LU factorization of A from torch.lu() of size (*, m, m), where * is zero or more batch dimensions.
  • LU_pivots (IntTensor) – the pivots of the LU factorization from torch.lu() of size (*, m), where * is zero or more batch dimensions. The batch dimensions of LU_pivots must be equal to the batch dimensions of LU_data.
Keyword Arguments
  • out (Tensor, optional) – the output tensor.
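
The shape requirements above can be checked with a small sketch. The batch size of 4, m = 3, k = 2, and the diagonally dominant construction of A are illustrative assumptions, chosen only so that the systems are well conditioned; they are not part of the API.

```python
import torch

torch.manual_seed(0)

# Assumed sizes for illustration: batch of 4 systems, m = 3, k = 2.
# Adding 3 * I to entries drawn from [0, 1) makes each matrix strictly
# diagonally dominant, hence invertible and well conditioned.
eye = torch.eye(3, dtype=torch.float64)
A = torch.rand(4, 3, 3, dtype=torch.float64) + 3 * eye
b = torch.rand(4, 3, 2, dtype=torch.float64)  # RHS of size (*, m, k)

# torch.lu returns the packed factorization and the pivots.
LU_data, LU_pivots = torch.lu(A)
print(LU_data.shape)                    # size (*, m, m)
print(LU_pivots.shape, LU_pivots.dtype)  # size (*, m), integer pivots

# Solve all systems in the batch at once; x has the same size as b.
x = torch.lu_solve(b, LU_data, LU_pivots)
residual = (A @ x - b).abs().max()
print(x.shape, residual)
```

The batch dimensions of LU_data and LU_pivots match here because both come from the same torch.lu() call.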

Example:

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> A_LU = torch.lu(A)
>>> x = torch.lu_solve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x) - b)
tensor(1.00000e-07 *
       2.8312)
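
A common reason to keep the factorization separate from the solve is to reuse it for several right-hand sides. The following is a minimal sketch of that pattern, assuming a single unbatched 3 × 3 matrix (zero batch dimensions) and three hypothetical right-hand sides; the names rhs_list and solutions are illustrative only.

```python
import torch

torch.manual_seed(0)

# Assumed well-conditioned matrix (strictly diagonally dominant).
A = torch.rand(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)

# Factor once.
LU_data, LU_pivots = torch.lu(A)

# Reuse the same factorization for several right-hand sides of size (m, k).
rhs_list = [torch.rand(3, 1, dtype=torch.float64) for _ in range(3)]
solutions = [torch.lu_solve(b, LU_data, LU_pivots) for b in rhs_list]

# Largest absolute residual of A x - b over all solves.
max_residual = max(
    (A @ x - b).abs().max().item() for x, b in zip(solutions, rhs_list)
)
print(max_residual)
```

Each call to torch.lu_solve here only performs the triangular solves, so the cost of the factorization is paid once.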

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.lu_solve.html