ML Katas

Channel-wise Max Pooling

easy (<10 mins) einops reduction pooling
this year by E

Description

Perform max pooling over the channel dimension. Given a tensor of shape (B, C, H, W), take the maximum over the C channels at each spatial location (h, w) of every batch element. The output should have the shape (B, H, W).
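To make the target behavior concrete, here is a small worked example in plain PyTorch rather than einops, so it does not give away the exercise. The input values are arbitrary and chosen only for illustration; `torch.amax` serves as a reference for what "maximum across channels" means.

```python
import torch

# Tiny illustrative input with B=1, C=2, H=2, W=2 (values are arbitrary).
x = torch.tensor([[[[1.0, 5.0],
                    [3.0, 0.0]],   # channel 0
                   [[4.0, 2.0],
                    [3.0, 7.0]]]]) # channel 1

# Reference reduction: take the max over dim=1, the channel axis.
pooled = torch.amax(x, dim=1)

print(pooled.shape)  # torch.Size([1, 2, 2])
print(pooled)        # tensor([[[4., 5.], [3., 7.]]])
```

Each output entry is the larger of the two channel values at that pixel, and the channel axis disappears from the shape.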

Starter Code

import torch
from einops import reduce

def channel_max_pool(tensor):
    """Return the per-pixel maximum over channels: (B, C, H, W) -> (B, H, W)."""
    # Your einops code here
    pass

Verification

For an input of shape (16, 3, 224, 224), the output should have the shape (16, 224, 224).
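One way to run this check is sketched below. The helper name `check_channel_max_pool` is hypothetical and not part of the kata. It assumes that comparing against `torch.amax` over the channel axis is an acceptable reference, and it also checks values, which goes slightly beyond the shape requirement stated above.

```python
import torch

def check_channel_max_pool(fn):
    """Check a candidate channel-max-pool function (hypothetical helper)."""
    torch.manual_seed(0)  # fixed seed so the check is reproducible
    x = torch.randn(16, 3, 224, 224)
    out = fn(x)

    # Shape requirement from the kata: (16, 3, 224, 224) -> (16, 224, 224).
    assert tuple(out.shape) == (16, 224, 224), f"unexpected shape {tuple(out.shape)}"

    # Value check against a plain-PyTorch reference (an assumption of this sketch).
    expected = torch.amax(x, dim=1)
    assert torch.equal(out, expected), "values differ from the channel-wise maximum"
```

Once `channel_max_pool` is implemented, calling `check_channel_max_pool(channel_max_pool)` should return without raising an `AssertionError`.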