torch.where

torch.where(condition, x, y) → Tensor

Return a tensor of elements selected from either x or y, depending on condition.

The operation is defined as:

outi={xiif conditioniyiotherwise\text{out}_i = \begin{cases} \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \\ \end{cases}

Note

The tensors condition, x, y must be broadcastable.

Note

Currently valid scalar and tensor combination are 1. Scalar of floating dtype and torch.double 2. Scalar of integral dtype and torch.long 3. Scalar of complex dtype and torch.complex128

Parameters
  • condition (BoolTensor) – When True (nonzero), yield x, otherwise yield y
  • x (Tensor or Scalar) – value (if :attr:x is a scalar) or values selected at indices where condition is True
  • y (Tensor or Scalar) – value (if :attr:x is a scalar) or values selected at indices where condition is False
Returns

A tensor of shape equal to the broadcasted shape of condition, x, y

Return type

Tensor

Example:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) → tuple of LongTensor

torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True).

Note

See also torch.nonzero().

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.where.html