
Commit 1fd7cf1

franklinic and claude committed
fix: address PR #514 review comments for validation and docs
- Update JIT-Compiler-Usage-Guide.md to reflect backward pass support
  - Remove outdated "backward pass not yet supported" statement
  - Mark backward pass as completed in Future Enhancements
  - Add CompileBackward API methods to documentation
- Add comprehensive validation to Conv2DBackwardInput and Conv2DBackwardKernel
  - Validate tensor ranks, array lengths, stride/dilation positivity
  - Check batch/channel consistency between gradOutput, input, and kernel
- Add validation to ConvTranspose2D
  - Validate rank-4 tensors, stride/padding/outputPadding arrays
  - Check channel consistency between input and kernel
- Add robust axis validation to ReduceMax/ReduceMean/ReduceMeanBackward
  - Add ValidateAndNormalizeAxes helper method
  - Check for null, range validity, and duplicates
  - Normalize negative indices properly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 8fa368a commit 1fd7cf1

File tree

2 files changed: +82 -6 lines


docs/JIT-Compiler-Usage-Guide.md

Lines changed: 7 additions & 2 deletions
@@ -206,7 +206,6 @@ if (!stats.CacheHit)
 - Production deployments
 
 **Less beneficial for:**
-- Training (backward pass not yet supported)
 - Graphs that change structure frequently
 - Very small operations (compilation overhead)
@@ -291,7 +290,7 @@ if (stats.EstimatedMemoryBytes > threshold)
 ## Future Enhancements
 
 Planned improvements:
-- [ ] Support for backward pass (gradient) compilation
+- [x] Support for backward pass (gradient) compilation
 - [ ] GPU code generation
 - [ ] More fusion patterns
 - [ ] Advanced optimizations (loop unrolling, vectorization hints)
@@ -313,6 +312,12 @@ See the `examples/JitCompilerExample.cs` file for complete working examples.
 - `(Func<Tensor<T>[], Tensor<T>[]>, CompilationStats) CompileWithStats<T>(...)`
   - Compiles and returns statistics
 
+- `Func<Tensor<T>[], Tensor<T>[]> CompileBackward<T>(ComputationNode<T> outputNode, List<ComputationNode<T>> inputs)`
+  - Compiles a backward pass (gradient computation) graph to executable code
+
+- `(Func<Tensor<T>[], Tensor<T>[]>, CompilationStats) CompileBackwardWithStats<T>(...)`
+  - Compiles backward pass and returns statistics
+
 - `void ClearCache()`
   - Clears the compiled graph cache
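As context for this documentation change, here is a hedged usage sketch of the two backward-compilation entry points added above. The `CompileBackward`/`CompileBackwardWithStats` signatures and the `CompilationStats.CacheHit` property come from the guide's diff; the `JitCompiler` instance name, the graph arguments, and the returned gradient ordering are illustrative assumptions, not confirmed API.

```csharp
// Hedged sketch: signatures are taken from the documentation diff above; the JitCompiler
// entry point, graph construction, and input tensors are assumptions for illustration.
using System;
using System.Collections.Generic;

public static class BackwardCompileSketch
{
    public static Tensor<float>[] ComputeGradients(
        JitCompiler compiler,                      // hypothetical compiler instance
        ComputationNode<float> loss,               // output node of the graph
        List<ComputationNode<float>> parameters,   // nodes to differentiate with respect to
        Tensor<float>[] parameterValues)
    {
        // Compile the gradient graph once and reuse the resulting delegate.
        Func<Tensor<float>[], Tensor<float>[]> backward =
            compiler.CompileBackward<float>(loss, parameters);

        // Variant that also reports compilation statistics.
        var (_, stats) = compiler.CompileBackwardWithStats<float>(loss, parameters);
        Console.WriteLine($"Cache hit: {stats.CacheHit}");

        // Presumably one gradient tensor per input node, in the order they were passed.
        return backward(parameterValues);
    }
}
```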

src/AiDotNet.Tensors/Engines/CpuEngine.cs

Lines changed: 75 additions & 4 deletions
@@ -2443,6 +2443,17 @@ public Tensor<T> Conv2DBackwardInput<T>(Tensor<T> gradOutput, Tensor<T> kernel,
 {
     if (gradOutput == null) throw new ArgumentNullException(nameof(gradOutput));
     if (kernel == null) throw new ArgumentNullException(nameof(kernel));
+    if (inputShape == null || inputShape.Length != 4) throw new ArgumentException("inputShape must be array of 4 elements [batch, inChannels, height, width]", nameof(inputShape));
+    if (gradOutput.Rank != 4) throw new ArgumentException($"Conv2DBackwardInput requires 4D gradOutput tensor. Got rank {gradOutput.Rank}.", nameof(gradOutput));
+    if (kernel.Rank != 4) throw new ArgumentException($"Conv2DBackwardInput requires 4D kernel tensor. Got rank {kernel.Rank}.", nameof(kernel));
+    if (stride == null || stride.Length != 2) throw new ArgumentException("Stride must be array of 2 elements", nameof(stride));
+    if (stride[0] <= 0 || stride[1] <= 0) throw new ArgumentException("Stride elements must be positive", nameof(stride));
+    if (padding == null || padding.Length != 2) throw new ArgumentException("Padding must be array of 2 elements", nameof(padding));
+    if (dilation == null || dilation.Length != 2) throw new ArgumentException("Dilation must be array of 2 elements", nameof(dilation));
+    if (dilation[0] <= 0 || dilation[1] <= 0) throw new ArgumentException("Dilation elements must be positive", nameof(dilation));
+    if (gradOutput.Shape[0] != inputShape[0]) throw new ArgumentException($"gradOutput batch size ({gradOutput.Shape[0]}) must match inputShape batch size ({inputShape[0]})");
+    if (gradOutput.Shape[1] != kernel.Shape[0]) throw new ArgumentException($"gradOutput outChannels ({gradOutput.Shape[1]}) must match kernel outChannels ({kernel.Shape[0]})");
+    if (inputShape[1] != kernel.Shape[1]) throw new ArgumentException($"inputShape inChannels ({inputShape[1]}) must match kernel inChannels ({kernel.Shape[1]})");
 
     var numOps = MathHelper.GetNumericOperations<T>();
     int batch = inputShape[0];
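The checks above cover rank, batch, and channel consistency; the spatial sizes of gradOutput and inputShape are related by the usual Conv2D output-size arithmetic. The following standalone sketch shows that arithmetic for reference only; it is not code from CpuEngine.cs and makes no claim about the engine's internals.

```csharp
using System;

static class ConvShapeSketch
{
    // Standard Conv2D output-size formula relating input size, kernel, stride,
    // padding, and dilation to the expected gradOutput spatial dimensions.
    static (int outH, int outW) Conv2DOutputSize(
        int inH, int inW, int kH, int kW,
        int[] stride, int[] padding, int[] dilation)
    {
        // Effective kernel extent once dilation is applied.
        int effKH = dilation[0] * (kH - 1) + 1;
        int effKW = dilation[1] * (kW - 1) + 1;

        int outH = (inH + 2 * padding[0] - effKH) / stride[0] + 1;
        int outW = (inW + 2 * padding[1] - effKW) / stride[1] + 1;
        return (outH, outW);
    }

    static void Main()
    {
        // 32x32 input, 3x3 kernel, stride 1, padding 1, dilation 1 -> 32x32 output,
        // so gradOutput from that forward pass should be [batch, outChannels, 32, 32].
        var (h, w) = Conv2DOutputSize(32, 32, 3, 3,
            new[] { 1, 1 }, new[] { 1, 1 }, new[] { 1, 1 });
        Console.WriteLine($"{h}x{w}");
    }
}
```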
@@ -2512,6 +2523,17 @@ public Tensor<T> Conv2DBackwardKernel<T>(Tensor<T> gradOutput, Tensor<T> input,
 {
     if (gradOutput == null) throw new ArgumentNullException(nameof(gradOutput));
     if (input == null) throw new ArgumentNullException(nameof(input));
+    if (kernelShape == null || kernelShape.Length != 4) throw new ArgumentException("kernelShape must be array of 4 elements [outChannels, inChannels, kernelHeight, kernelWidth]", nameof(kernelShape));
+    if (gradOutput.Rank != 4) throw new ArgumentException($"Conv2DBackwardKernel requires 4D gradOutput tensor. Got rank {gradOutput.Rank}.", nameof(gradOutput));
+    if (input.Rank != 4) throw new ArgumentException($"Conv2DBackwardKernel requires 4D input tensor. Got rank {input.Rank}.", nameof(input));
+    if (stride == null || stride.Length != 2) throw new ArgumentException("Stride must be array of 2 elements", nameof(stride));
+    if (stride[0] <= 0 || stride[1] <= 0) throw new ArgumentException("Stride elements must be positive", nameof(stride));
+    if (padding == null || padding.Length != 2) throw new ArgumentException("Padding must be array of 2 elements", nameof(padding));
+    if (dilation == null || dilation.Length != 2) throw new ArgumentException("Dilation must be array of 2 elements", nameof(dilation));
+    if (dilation[0] <= 0 || dilation[1] <= 0) throw new ArgumentException("Dilation elements must be positive", nameof(dilation));
+    if (gradOutput.Shape[0] != input.Shape[0]) throw new ArgumentException($"gradOutput batch size ({gradOutput.Shape[0]}) must match input batch size ({input.Shape[0]})");
+    if (gradOutput.Shape[1] != kernelShape[0]) throw new ArgumentException($"gradOutput outChannels ({gradOutput.Shape[1]}) must match kernelShape outChannels ({kernelShape[0]})");
+    if (input.Shape[1] != kernelShape[1]) throw new ArgumentException($"input inChannels ({input.Shape[1]}) must match kernelShape inChannels ({kernelShape[1]})");
 
     var numOps = MathHelper.GetNumericOperations<T>();
 
@@ -3019,6 +3041,15 @@ public Tensor<T> ConvTranspose2D<T>(Tensor<T> input, Tensor<T> kernel, int[] str
 {
     if (input == null) throw new ArgumentNullException(nameof(input));
     if (kernel == null) throw new ArgumentNullException(nameof(kernel));
+    if (input.Rank != 4) throw new ArgumentException($"ConvTranspose2D requires 4D input tensor. Got rank {input.Rank}.", nameof(input));
+    if (kernel.Rank != 4) throw new ArgumentException($"ConvTranspose2D requires 4D kernel tensor. Got rank {kernel.Rank}.", nameof(kernel));
+    if (stride == null || stride.Length != 2) throw new ArgumentException("Stride must be array of 2 elements", nameof(stride));
+    if (stride[0] <= 0 || stride[1] <= 0) throw new ArgumentException("Stride elements must be positive", nameof(stride));
+    if (padding == null || padding.Length != 2) throw new ArgumentException("Padding must be array of 2 elements", nameof(padding));
+    if (padding[0] < 0 || padding[1] < 0) throw new ArgumentException("Padding elements must be non-negative", nameof(padding));
+    if (outputPadding == null || outputPadding.Length != 2) throw new ArgumentException("OutputPadding must be array of 2 elements", nameof(outputPadding));
+    if (outputPadding[0] < 0 || outputPadding[1] < 0) throw new ArgumentException("OutputPadding elements must be non-negative", nameof(outputPadding));
+    if (input.Shape[1] != kernel.Shape[0]) throw new ArgumentException($"Input inChannels ({input.Shape[1]}) must match kernel inChannels ({kernel.Shape[0]})");
 
     var numOps = MathHelper.GetNumericOperations<T>();
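For reference, the conventional ConvTranspose2D output-size formula shows how the validated stride, padding, and outputPadding arguments interact. This standalone sketch assumes the common no-dilation formulation and is illustrative only, not taken from CpuEngine.cs.

```csharp
using System;

static class ConvTransposeShapeSketch
{
    // Conventional transposed-convolution output-size formula (no dilation), shown
    // only to illustrate how stride, padding, and outputPadding interact.
    static (int outH, int outW) OutputSize(
        int inH, int inW, int kH, int kW,
        int[] stride, int[] padding, int[] outputPadding)
    {
        int outH = (inH - 1) * stride[0] - 2 * padding[0] + kH + outputPadding[0];
        int outW = (inW - 1) * stride[1] - 2 * padding[1] + kW + outputPadding[1];
        return (outH, outW);
    }

    static void Main()
    {
        // 16x16 input, 3x3 kernel, stride 2, padding 1, outputPadding 1 -> 32x32,
        // the usual "upsample by 2" configuration.
        var (h, w) = OutputSize(16, 16, 3, 3,
            new[] { 2, 2 }, new[] { 1, 1 }, new[] { 1, 1 });
        Console.WriteLine($"{h}x{w}");
    }
}
```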

@@ -4092,15 +4123,50 @@ public Tensor<T> LayerNormBackward<T>(Tensor<T> gradOutput, Tensor<T> input, Ten
 
 #region Tensor Reduction Operations
 
+/// <summary>
+/// Validates and normalizes reduction axes.
+/// </summary>
+/// <param name="axes">The axes to validate</param>
+/// <param name="rank">The tensor rank</param>
+/// <returns>Normalized, validated, and sorted unique axes</returns>
+private static int[] ValidateAndNormalizeAxes(int[] axes, int rank)
+{
+    if (axes == null)
+        throw new ArgumentNullException(nameof(axes), "Axes cannot be null");
+
+    if (axes.Length == 0)
+        throw new ArgumentException("Axes array cannot be empty", nameof(axes));
+
+    var normalizedAxes = new int[axes.Length];
+    for (int i = 0; i < axes.Length; i++)
+    {
+        int axis = axes[i];
+        // Normalize negative indices
+        int normalized = axis < 0 ? rank + axis : axis;
+
+        if (normalized < 0 || normalized >= rank)
+            throw new ArgumentOutOfRangeException(nameof(axes), $"Axis {axis} is out of range for tensor with rank {rank}. Valid range is [{-rank}, {rank - 1}].");
+
+        normalizedAxes[i] = normalized;
+    }
+
+    // Check for duplicates
+    var uniqueAxes = normalizedAxes.Distinct().ToArray();
+    if (uniqueAxes.Length != axes.Length)
+        throw new ArgumentException("Duplicate axes are not allowed", nameof(axes));
+
+    return uniqueAxes.OrderBy(a => a).ToArray();
+}
+
 /// <inheritdoc/>
 public Tensor<T> ReduceMax<T>(Tensor<T> input, int[] axes, bool keepDims, out int[] maxIndices)
 {
     var numOps = MathHelper.GetNumericOperations<T>();
     var inputShape = input.Shape;
     var inputData = input.ToArray();
 
-    // Normalize axes
-    var normalizedAxes = axes.Select(a => a < 0 ? inputShape.Length + a : a).OrderBy(a => a).ToArray();
+    // Validate and normalize axes
+    var normalizedAxes = ValidateAndNormalizeAxes(axes, inputShape.Length);
 
     // Compute output shape
     var outputShapeList = new List<int>();
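To illustrate the behaviour the new helper enforces, here is a standalone mirror of the ValidateAndNormalizeAxes logic shown above, with two representative calls. The wrapper class, simplified exception messages, and `Main` method are illustrative only; the validation rules themselves are restated from the diff.

```csharp
using System;
using System.Linq;

static class AxisNormalizationDemo
{
    // Mirrors the ValidateAndNormalizeAxes logic from the diff above (simplified
    // messages) so its behaviour can be exercised standalone.
    static int[] ValidateAndNormalizeAxes(int[] axes, int rank)
    {
        if (axes == null) throw new ArgumentNullException(nameof(axes));
        if (axes.Length == 0) throw new ArgumentException("Axes array cannot be empty", nameof(axes));

        var normalized = axes.Select(a => a < 0 ? rank + a : a).ToArray();
        if (normalized.Any(a => a < 0 || a >= rank))
            throw new ArgumentOutOfRangeException(nameof(axes));
        if (normalized.Distinct().Count() != axes.Length)
            throw new ArgumentException("Duplicate axes are not allowed", nameof(axes));

        return normalized.OrderBy(a => a).ToArray();
    }

    static void Main()
    {
        // Negative indices are normalized: for rank 4, axis -1 means axis 3.
        Console.WriteLine(string.Join(",", ValidateAndNormalizeAxes(new[] { -1, 0 }, 4)));  // 0,3

        // Duplicates are rejected even when written differently (-1 and 3 are the same axis).
        try { ValidateAndNormalizeAxes(new[] { -1, 3 }, 4); }
        catch (ArgumentException e) { Console.WriteLine(e.Message); }
    }
}
```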
@@ -4192,7 +4258,8 @@ public Tensor<T> ReduceMean<T>(Tensor<T> input, int[] axes, bool keepDims)
     var inputShape = input.Shape;
     var inputData = input.ToArray();
 
-    var normalizedAxes = axes.Select(a => a < 0 ? inputShape.Length + a : a).OrderBy(a => a).ToArray();
+    // Validate and normalize axes
+    var normalizedAxes = ValidateAndNormalizeAxes(axes, inputShape.Length);
 
     var outputShapeList = new List<int>();
     for (int i = 0; i < inputShape.Length; i++)
@@ -4258,11 +4325,15 @@ public Tensor<T> ReduceMean<T>(Tensor<T> input, int[] axes, bool keepDims)
 /// <inheritdoc/>
 public Tensor<T> ReduceMeanBackward<T>(Tensor<T> gradOutput, int[] inputShape, int[] axes)
 {
+    if (inputShape == null || inputShape.Length == 0)
+        throw new ArgumentNullException(nameof(inputShape), "inputShape cannot be null or empty");
+
     var numOps = MathHelper.GetNumericOperations<T>();
     int inputSize = inputShape.Aggregate(1, (a, b) => a * b);
     var gradInputData = new T[inputSize];
 
-    var normalizedAxes = axes.Select(a => a < 0 ? inputShape.Length + a : a).ToArray();
+    // Validate and normalize axes
+    var normalizedAxes = ValidateAndNormalizeAxes(axes, inputShape.Length);
 
     int reduceCount = 1;
     foreach (var ax in normalizedAxes)
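The `reduceCount` loop above feeds the mean-reduction gradient rule: each input element receives the upstream gradient divided by the number of elements that were averaged. The sketch below restates that rule for the simplest case (reducing over every element of a flat array); it is an illustration, not code from CpuEngine.cs.

```csharp
using System;
using System.Linq;

static class ReduceMeanBackwardSketch
{
    // Mean-reduction gradient rule, simplified to reducing all axes of a flat array:
    // each input element gets gradOutput / (number of reduced elements).
    static float[] Backward(float gradOutput, int inputSize)
    {
        int reduceCount = inputSize;   // every element was averaged together
        return Enumerable.Repeat(gradOutput / reduceCount, inputSize).ToArray();
    }

    static void Main()
    {
        // Mean over 4 elements: d(mean)/d(x_i) = 1/4, so an upstream gradient of 1.0
        // becomes 0.25 for every input element.
        Console.WriteLine(string.Join(", ", Backward(1.0f, 4)));  // 0.25, 0.25, 0.25, 0.25
    }
}
```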

0 commit comments
