ChannelShuffle

class torch.nn.ChannelShuffle(groups)

Divide the channels in a tensor of shape $(*, C, H, W)$ into $g$ groups and rearrange them as $(*, \frac{C}{g}, g, H, W)$, while keeping the original tensor shape.

Parameters

groups (int) – number of groups to divide the channels into. The number of channels C must be divisible by groups.
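To make the rearrangement concrete, the following pure-Python sketch computes which input channel ends up at each output position. The helper name `shuffled_channel_order` is hypothetical and not part of PyTorch; it only illustrates the index arithmetic implied by viewing the channels as (g, C/g), swapping those two axes, and flattening again.

```python
def shuffled_channel_order(num_channels, groups):
    """Return, for each output channel, the index of the input channel it comes from.

    Hypothetical helper for illustration only; it is not a PyTorch function.
    """
    if groups <= 0 or num_channels % groups != 0:
        raise ValueError("num_channels must be a positive multiple of groups")
    channels_per_group = num_channels // groups
    order = []
    for k in range(num_channels):
        group = k % groups            # which group the source channel belongs to
        offset = k // groups          # position of the source channel inside its group
        order.append(group * channels_per_group + offset)
    return order


print(shuffled_channel_order(4, 2))  # [0, 2, 1, 3]
print(shuffled_channel_order(6, 3))  # [0, 2, 4, 1, 3, 5]
```

With 4 channels and 2 groups, the second and third channels swap places while the first and last stay put.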

Examples:

>>> channel_shuffle = nn.ChannelShuffle(2)
>>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
>>> print(input)
[[[[1, 2],
   [3, 4]],
  [[5, 6],
   [7, 8]],
  [[9, 10],
   [11, 12]],
  [[13, 14],
   [15, 16]],
 ]]
>>> output = channel_shuffle(input)
>>> print(output)
[[[[1, 2],
   [3, 4]],
  [[9, 10],
   [11, 12]],
  [[5, 6],
   [7, 8]],
  [[13, 14],
   [15, 16]],
 ]]
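The same result can be reproduced with plain reshape and transpose operations, which may help when reasoning about what the module does. This is a minimal sketch, not the library's implementation: the function name `channel_shuffle_reference` is hypothetical, and it assumes a 4-D input of shape (N, C, H, W).

```python
import torch
from torch import nn


def channel_shuffle_reference(x, groups):
    """Reference shuffle built from reshape/transpose (hypothetical helper)."""
    n, c, h, w = x.shape  # assumes a 4-D (N, C, H, W) tensor
    if c % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    # Split channels into (groups, C // groups), swap the two axes, then flatten back.
    x = x.reshape(n, groups, c // groups, h, w)
    x = x.transpose(1, 2)
    return x.reshape(n, c, h, w)


x = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
result = channel_shuffle_reference(x, 2)

# Compare against the built-in module.
print(torch.equal(result, nn.ChannelShuffle(2)(x)))  # True
```

The output keeps the input shape; only the order of the channels along dimension 1 changes.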

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.ChannelShuffle.html