Implementing a simplified Beam Search Decoder with `gather` and `scatter`
Beam search is a popular decoding algorithm used in machine translation and text generation. A key step in beam search is selecting the top-k most likely next tokens and updating the corresponding beams; this exercise asks you to implement a simplified version of that step. You are given the current beam scores, shaped `(batch_size, beam_width)`, and the predicted log probabilities for the next token, shaped `(batch_size, vocab_size)`. Your task is to:
1. Calculate the new candidate scores by adding the current beam scores to the log probabilities (broadcasting across the vocabulary).
2. Find the top-k new scores and their corresponding indices (which will be a combination of the previous beam and the new token).
3. Use `gather` and `scatter` to update the beams and their scores for the next step.
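The three steps above can be sketched in PyTorch as follows. This is a minimal reference sketch, not the official solution: the function name `beam_step` and all tensor names are illustrative choices, and it assumes the log probabilities are shared across beams (as the stated `(batch_size, vocab_size)` shape implies). It uses `gather` to reorder the surviving beams; a variant could use `scatter` to write the updated scores into a preallocated buffer.

```python
import torch

def beam_step(beam_scores, log_probs, beams):
    """One simplified beam-search update (names are illustrative).

    beam_scores: (batch_size, beam_width)  cumulative log probabilities
    log_probs:   (batch_size, vocab_size)  next-token log probabilities
    beams:       (batch_size, beam_width, seq_len) token ids so far
    """
    batch_size, beam_width = beam_scores.shape
    vocab_size = log_probs.size(-1)

    # Step 1: broadcast-add to score every (beam, token) pair:
    # (batch, beam, 1) + (batch, 1, vocab) -> (batch, beam, vocab)
    new_scores = beam_scores.unsqueeze(-1) + log_probs.unsqueeze(1)

    # Step 2: flatten beams and vocabulary, then take the top-k overall.
    flat = new_scores.view(batch_size, beam_width * vocab_size)
    top_scores, top_idx = flat.topk(beam_width, dim=-1)
    beam_idx = top_idx // vocab_size   # which previous beam each winner extends
    token_idx = top_idx % vocab_size   # which token extends it

    # Step 3: gather the surviving beams, then append the chosen tokens.
    idx = beam_idx.unsqueeze(-1).expand(-1, -1, beams.size(-1))
    new_beams = beams.gather(1, idx)
    new_beams = torch.cat([new_beams, token_idx.unsqueeze(-1)], dim=-1)
    return top_scores, new_beams
```

Note how flattening the `(beam, vocab)` axes lets a single `topk` rank all `beam_width * vocab_size` candidates at once; integer division and modulo then recover which beam and token each winner corresponds to.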
This is a challenging exercise that mimics a real-world implementation in sequence-to-sequence models.
Verification:
- The updated beam scores and indices should correctly reflect the top-k most probable sequences.
- The shapes of the updated tensors should remain consistent for the next iteration of the beam search.
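One way to check the first criterion independently of any tensor code is a brute-force oracle in plain Python: enumerate every (previous beam, token) pair and keep the k best totals. The helper name `topk_candidates` is a hypothetical choice for this sketch, not part of the exercise.

```python
import heapq

def topk_candidates(beam_scores, log_probs, k):
    """Brute-force oracle: score every (beam, token) pair, keep the k best.

    beam_scores: list of floats, one per beam
    log_probs:   list of floats, one per vocabulary token
    Returns a list of (new_score, previous_beam_index, token_index),
    best first.
    """
    candidates = [
        (score + lp, beam, token)
        for beam, score in enumerate(beam_scores)
        for token, lp in enumerate(log_probs)
    ]
    return heapq.nlargest(k, candidates)
```

Comparing your `gather`/`scatter` implementation against this oracle on small random inputs is a quick way to catch indexing mistakes, since the oracle's beam and token indices must match the ones your top-k step recovers.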