Channel-wise Max Pooling
Description
Perform max pooling over the channel dimension. Given a tensor of shape (B, C, H, W), find the maximum value across all channels for each spatial location. The output should have the shape (B, H, W).
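To make the operation concrete, here is a small worked example on a hand-written tensor. It deliberately uses plain `torch.amax` rather than einops, so it only illustrates the expected result and is not the intended solution; the tensor values are made up for illustration.

```python
import torch

# Illustrative input: B=1 image, C=2 channels, H=W=2 spatial grid.
x = torch.tensor([[[[1.0, 5.0],
                    [3.0, 0.0]],
                   [[4.0, 2.0],
                    [3.0, 7.0]]]])

# Reduce over dim=1 (the channel axis): each spatial location keeps
# the larger of its two channel values.
pooled = torch.amax(x, dim=1)

print(pooled.shape)     # torch.Size([1, 2, 2])
print(pooled.tolist())  # [[[4.0, 5.0], [3.0, 7.0]]]
```

The channel axis disappears from the result, which is why the output shape is (B, H, W) rather than (B, 1, H, W).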
Starter Code
import torch
from einops import reduce

def channel_max_pool(tensor):
    # Your einops code here
    pass
Verification
For an input of shape (16, 3, 224, 224), the output should have the shape (16, 224, 224).