ML Katas

Matrix Multiplication and Efficiency

medium (<30 mins) optimization matrix-multiplication linear-algebra algorithms computational-complexity
this year by E

Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices A (size m×k) and B (size k×n), their product C = AB is an m×n matrix whose entries are

C_{ij} = \sum_{p=1}^{k} A_{ip} B_{pj}

The standard algorithm for matrix multiplication has a time complexity of O(mkn). For square matrices of size N×N, this is O(N^3).

Your task is twofold:

1. Implement standard matrix multiplication from scratch using nested loops.
2. (Conceptual/Research) Research Strassen's algorithm for matrix multiplication. Understand its principle and how it achieves a lower theoretical time complexity, O(N^{\log_2 7}) ≈ O(N^{2.807}). You don't need to implement Strassen's algorithm, but be able to describe its core idea and why it's faster for very large matrices.
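For the conceptual part, it can help to see Strassen's divide-and-conquer structure concretely. The sketch below is one common formulation, assuming square matrices whose side length is a power of two; the `leaf` threshold (the block size at which we fall back to ordinary multiplication) is an illustrative choice, not part of the task:

```python
import numpy as np

def strassen(A, B, leaf=64):
    """Strassen's algorithm for n x n matrices with n a power of two."""
    n = A.shape[0]
    # Below the leaf size, recursion overhead outweighs the saved
    # multiplication, so fall back to the ordinary product.
    if n <= leaf:
        return A @ B
    h = n // 2
    # Split each matrix into four h x h blocks.
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
    # Seven recursive products instead of the naive eight --
    # this is the source of the O(N^{log2 7}) complexity.
    M1 = strassen(A11 + A22, B11 + B22, leaf)
    M2 = strassen(A21 + A22, B11, leaf)
    M3 = strassen(A11, B12 - B22, leaf)
    M4 = strassen(A22, B21 - B11, leaf)
    M5 = strassen(A11 + A12, B22, leaf)
    M6 = strassen(A21 - A11, B11 + B12, leaf)
    M7 = strassen(A12 - A22, B21 + B22, leaf)
    # Recombine the seven products into the four blocks of C.
    C11 = M1 + M4 - M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 - M2 + M3 + M6
    return np.block([[C11, C12], [C21, C22]])
```

Note the trade-off visible in the code: each recursion level replaces one multiplication with many extra additions and block copies, which is part of why the asymptotic win only pays off for large matrices.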

Implementation Details:

* naive_matrix_multiply(A, B):
  * A: a 2D NumPy array.
  * B: a 2D NumPy array.
  * Check that the dimensions are compatible.
  * Implement the multiplication using nested for loops.
  * Return the resulting product matrix.
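A minimal sketch along these lines (the dimension checks and the i-j-p loop order are one reasonable choice, not the only one):

```python
import numpy as np

def naive_matrix_multiply(A, B):
    """Multiply two 2D arrays with explicit nested loops, O(mkn)."""
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Both inputs must be 2D arrays")
    m, k = A.shape
    k2, n = B.shape
    if k != k2:
        raise ValueError(f"Incompatible shapes {A.shape} and {B.shape}")
    # Accumulate each C[i, j] as the dot product of row i of A
    # with column j of B.
    C = np.zeros((m, n), dtype=np.result_type(A, B))
    for i in range(m):
        for j in range(n):
            s = 0
            for p in range(k):
                s += A[i, p] * B[p, j]
            C[i, j] = s
    return C
```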

Verification:

1. Create small, compatible NumPy matrices, e.g. A = np.array([[1, 2], [3, 4]]) and B = np.array([[5, 6], [7, 8]]).
2. Compute the product using your naive_matrix_multiply function.
3. Compare the result with np.dot(A, B) or A @ B; the results should be identical.
4. (Conceptual) For Strassen's algorithm, describe in text (or pseudocode) its divide-and-conquer approach and the sub-problems it solves recursively. Discuss its practical implications (e.g., overhead for small matrices, cache efficiency).
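The first three verification steps might look like this; a compact comprehension-based naive_matrix_multiply is inlined here so the snippet runs on its own (your loop-based version should pass the same check):

```python
import numpy as np

def naive_matrix_multiply(A, B):
    # Compact naive product for verification purposes only.
    m, k = A.shape
    _, n = B.shape
    return np.array([[sum(A[i, p] * B[p, j] for p in range(k))
                      for j in range(n)] for i in range(m)])

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
C = naive_matrix_multiply(A, B)

# Cross-check against NumPy's built-in product.
print(C)
print(np.array_equal(C, A @ B))  # True
```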

Research Question (answer as part of your content): Why isn't Strassen's algorithm universally used in libraries like NumPy or BLAS, despite its better theoretical complexity?