Problem statement
The BLAS level-2 geru kernel (general rank-1 update) is not implemented for the PyTorch backend in kernel_course.pytorch_ops. The README BLAS table lists no PyTorch support for geru, leaving this important operation absent from the PyTorch kernel set.
Without a PyTorch geru kernel:
- there is no unified API for $A = A + \alpha x y^\top$ on PyTorch tensors within kernel-course,
- cross-backend comparisons for rank-1 updates are not possible,
- higher-level modules must use ad-hoc PyTorch code instead of a standard kernel abstraction.
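Entrywise, the update adds $\alpha x_i y_j$ to every $A_{ij}$, which is why $x$ must have length $m$ and $y$ length $n$. A small worked instance, with values chosen purely for illustration:

$$
A_{ij} \leftarrow A_{ij} + \alpha\, x_i\, y_j, \qquad 1 \le i \le m,\; 1 \le j \le n
$$

$$
\alpha = 2,\quad x = \begin{pmatrix} 1 \\ 2 \end{pmatrix},\quad y = \begin{pmatrix} 3 \\ 4 \\ 5 \end{pmatrix},\quad A = 0_{2 \times 3}
\;\Longrightarrow\;
A + \alpha x y^\top = \begin{pmatrix} 6 & 8 & 10 \\ 12 & 16 & 20 \end{pmatrix}
$$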
Proposed solution
Add a PyTorch implementation of geru under kernel_course.pytorch_ops that matches the Python reference semantics and follows existing backend conventions.
Concretely:
- Introduce kernel_course/pytorch_ops/geru.py implementing $A = A + \alpha x y^\top$ for PyTorch tensors.
- Accept scalar alpha, vector x, vector y, and matrix A.
- Use idiomatic PyTorch operations (e.g. outer product plus in-place add) while maintaining correctness.
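A minimal sketch of the outer-product-plus-in-place-add approach listed above. The argument order (alpha, x, y, A), mutating A in place, and returning A are assumptions made for illustration, not confirmed project conventions.

```python
import torch


def geru(alpha: float, x: torch.Tensor, y: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Sketch of A = A + alpha * x y^T on PyTorch tensors (A is updated in place)."""
    # torch.outer(x, y) builds the (m, n) matrix whose entries are x[i] * y[j].
    # Tensor.add_(other, alpha=...) computes A += alpha * other in place.
    A.add_(torch.outer(x, y), alpha=alpha)
    return A


A = torch.zeros(2, 3)
geru(2.0, torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0]), A)
print(A)  # tensor([[ 6.,  8., 10.], [12., 16., 20.]])
```

torch.outer materializes a full m-by-n temporary before the add. PyTorch also provides Tensor.addr_(vec1, vec2, alpha=...), which performs the rank-1 update directly and may be the better fit for the final kernel.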
Alternatives considered
Relying exclusively on raw PyTorch operations outside of a dedicated kernel would:
- fragment API usage across the project,
- complicate benchmarking geru across backends,
- weaken the pedagogical mapping between mathematical kernels and backend implementations.
Implementation details
- Add kernel_course/pytorch_ops/geru.py with a public geru function.
- Validate shapes so that A has dimensions (m, n), x has length m, and y has length n.
- Ensure dtype and device handling is consistent with the other PyTorch kernels.
- Update kernel_course/pytorch_ops/__init__.py to expose geru.
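A sketch of the shape, dtype, and device checks listed above. The specific policy (rejecting mixed dtypes and devices instead of casting or copying) and the exception types are assumptions; the existing copy and swap kernels may follow different conventions. This variant uses Tensor.addr_, PyTorch's in-place rank-1 update, rather than an explicit outer product.

```python
import torch


def geru(alpha: float, x: torch.Tensor, y: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Sketch: in-place rank-1 update A += alpha * outer(x, y) with input validation.

    The checks and error messages are illustrative assumptions.
    """
    if A.dim() != 2:
        raise ValueError(f"A must be 2-D, got {A.dim()}-D")
    if x.dim() != 1 or y.dim() != 1:
        raise ValueError("x and y must be 1-D vectors")
    m, n = A.shape
    if x.shape[0] != m or y.shape[0] != n:
        raise ValueError(
            f"shape mismatch: A is ({m}, {n}), x has length {x.shape[0]}, "
            f"y has length {y.shape[0]}"
        )
    # Require a single dtype and a single device rather than silently converting.
    if not (x.dtype == y.dtype == A.dtype):
        raise TypeError(f"dtype mismatch: x={x.dtype}, y={y.dtype}, A={A.dtype}")
    if not (x.device == y.device == A.device):
        raise ValueError(f"device mismatch: x={x.device}, y={y.device}, A={A.device}")
    # Tensor.addr_ computes A <- beta * A + alpha * outer(x, y) in place; beta defaults to 1.
    return A.addr_(x, y, alpha=alpha)
```

To expose the function, kernel_course/pytorch_ops/__init__.py could re-export it with a line such as from .geru import geru, assuming the package uses relative imports for its other kernels.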
Use case
The PyTorch geru kernel will:
- provide a standard rank-1 update primitive for use in higher-level modules,
- serve as a target for cross-backend performance and correctness comparisons,
- help illustrate BLAS level-2 patterns in PyTorch.
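For the correctness comparisons mentioned above, one option is to check the PyTorch result against an entrywise loop. In this sketch, geru_reference and check_geru_against_reference are hypothetical helpers, not the project's actual Python reference, and torch.addr stands in for pytorch_ops.geru until that function exists.

```python
import torch


def geru_reference(alpha, x, y, A):
    """Hypothetical pure-Python reference: A[i][j] += alpha * x[i] * y[j] (in place)."""
    for i in range(len(x)):
        for j in range(len(y)):
            A[i][j] += alpha * x[i] * y[j]
    return A


def check_geru_against_reference(m: int = 5, n: int = 4, alpha: float = 1.5) -> None:
    """Compare a PyTorch rank-1 update with the loop reference on random float64 data."""
    gen = torch.Generator().manual_seed(0)
    A = torch.randn(m, n, dtype=torch.float64, generator=gen)
    x = torch.randn(m, dtype=torch.float64, generator=gen)
    y = torch.randn(n, dtype=torch.float64, generator=gen)

    # Stand-in for the kernel under test; swap in pytorch_ops.geru once it exists.
    got = torch.addr(A, x, y, alpha=alpha)

    # The reference works on plain nested lists (Python floats are float64).
    expected = geru_reference(alpha, x.tolist(), y.tolist(), A.tolist())
    torch.testing.assert_close(got, torch.tensor(expected, dtype=torch.float64))


check_geru_against_reference()
```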
Related work
- Existing PyTorch kernels: pytorch_ops.copy, pytorch_ops.swap.
- PyTorch outer product and rank-update patterns.
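PyTorch offers several equivalent ways to form the rank-1 update: torch.outer (torch.ger is a deprecated alias), broadcasting a column vector against a row vector, and torch.addr, which computes beta * input + alpha * outer(vec1, vec2) in a single call. This sketch, with arbitrary sizes and seed, checks that the three agree.

```python
import torch

torch.manual_seed(0)
m, n, alpha = 4, 3, 0.5
A = torch.randn(m, n)
x = torch.randn(m)
y = torch.randn(n)

via_outer = A + alpha * torch.outer(x, y)            # explicit outer product
via_broadcast = A + alpha * x[:, None] * y[None, :]  # (m, 1) * (1, n) broadcasts to (m, n)
via_addr = torch.addr(A, x, y, alpha=alpha)          # dedicated rank-1 update op

torch.testing.assert_close(via_outer, via_broadcast)
torch.testing.assert_close(via_outer, via_addr)
```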
Additional context
This issue contributes to completing the geru row for the PyTorch column in the README BLAS table.