[FEATURE REQUEST] geru PyTorch kernel implementation #26

@LoserCheems

Description
Problem statement

The BLAS level-2 geru kernel (general rank-1 update) is not implemented for the PyTorch backend in kernel_course.pytorch_ops. The README BLAS table lists no PyTorch support for geru, leaving this important operation absent from the PyTorch kernel set.

Without a PyTorch geru kernel:

  • there is no unified API for $A = A + \alpha x y^\top$ on PyTorch tensors within kernel-course,
  • cross-backend comparisons for rank-1 updates are not possible,
  • higher-level modules must use ad-hoc PyTorch code instead of a standard kernel abstraction.

Proposed solution

Add a PyTorch implementation of geru under kernel_course.pytorch_ops that matches the Python reference semantics and follows existing backend conventions.

Concretely:

  • Introduce kernel_course/pytorch_ops/geru.py implementing $A = A + \alpha x y^\top$ for PyTorch tensors.
  • Accept scalar alpha, vector x, vector y, and matrix A.
  • Use idiomatic PyTorch operations (e.g. outer product plus in-place add) while maintaining correctness.
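A minimal sketch of the proposed kernel, assuming a hypothetical `geru(alpha, x, y, A)` signature (the final kernel-course API may differ). It uses `torch.outer` plus an in-place `add_`, as suggested above:

```python
import torch

def geru(alpha: float, x: torch.Tensor, y: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # Rank-1 update: A <- A + alpha * x y^T, performed in place on A.
    # x has length m, y has length n, A has shape (m, n).
    A.add_(torch.outer(x, y), alpha=alpha)
    return A
```

Equivalently, PyTorch's built-in rank-1 update `A.addr_(x, y, alpha=alpha)` computes the same result in one fused call; either form preserves the in-place semantics of BLAS geru.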

Alternatives considered

Relying exclusively on raw PyTorch operations outside of a dedicated kernel would:

  • fragment API usage across the project,
  • complicate benchmarking geru across backends,
  • weaken the pedagogical mapping between mathematical kernels and backend implementations.

Implementation details

  • Add kernel_course/pytorch_ops/geru.py with a public geru function.
  • Validate shapes so that A has dimensions (m, n), x has length m, and y has length n.
  • Ensure consistent dtype and device handling with other PyTorch kernels.
  • Update kernel_course/pytorch_ops/__init__.py to expose geru.

Use case

The PyTorch geru kernel will:

  • provide a standard rank-1 update primitive for use in higher-level modules,
  • serve as a target for cross-backend performance and correctness comparisons,
  • help illustrate BLAS level-2 patterns in PyTorch.

Related work

  • Existing PyTorch kernels: pytorch_ops.copy, pytorch_ops.swap.
  • PyTorch outer product and rank-update patterns.

Additional context

This issue contributes to completing the geru row for the PyTorch column in the README BLAS table.
