You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
fix: address PR #514 review comments for validation and docs
- Update JIT-Compiler-Usage-Guide.md to reflect backward pass support
- Remove outdated "backward pass not yet supported" statement
- Mark backward pass as completed in Future Enhancements
- Add CompileBackward API methods to documentation
- Add comprehensive validation to Conv2DBackwardInput and Conv2DBackwardKernel
- Validate tensor ranks, array lengths, stride/dilation positivity
- Check batch/channel consistency between gradOutput, input, and kernel
- Add validation to ConvTranspose2D
- Validate rank-4 tensors, stride/padding/outputPadding arrays
- Check channel consistency between input and kernel
- Add robust axis validation to ReduceMax/ReduceMean/ReduceMeanBackward
- Add ValidateAndNormalizeAxes helper method
- Check for null, range validity, and duplicates
- Normalize negative indices properly
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
if(inputShape==null||inputShape.Length!=4)thrownewArgumentException("inputShape must be array of 4 elements [batch, inChannels, height, width]",nameof(inputShape));
if(stride==null||stride.Length!=2)thrownewArgumentException("Stride must be array of 2 elements",nameof(stride));
2450
+
if(stride[0]<=0||stride[1]<=0)thrownewArgumentException("Stride elements must be positive",nameof(stride));
2451
+
if(padding==null||padding.Length!=2)thrownewArgumentException("Padding must be array of 2 elements",nameof(padding));
2452
+
if(dilation==null||dilation.Length!=2)thrownewArgumentException("Dilation must be array of 2 elements",nameof(dilation));
2453
+
if(dilation[0]<=0||dilation[1]<=0)thrownewArgumentException("Dilation elements must be positive",nameof(dilation));
2454
+
if(gradOutput.Shape[0]!=inputShape[0])thrownewArgumentException($"gradOutput batch size ({gradOutput.Shape[0]}) must match inputShape batch size ({inputShape[0]})");
2455
+
if(gradOutput.Shape[1]!=kernel.Shape[0])thrownewArgumentException($"gradOutput outChannels ({gradOutput.Shape[1]}) must match kernel outChannels ({kernel.Shape[0]})");
2456
+
if(inputShape[1]!=kernel.Shape[1])thrownewArgumentException($"inputShape inChannels ({inputShape[1]}) must match kernel inChannels ({kernel.Shape[1]})");
2446
2457
2447
2458
varnumOps=MathHelper.GetNumericOperations<T>();
2448
2459
intbatch=inputShape[0];
@@ -2512,6 +2523,17 @@ public Tensor<T> Conv2DBackwardKernel<T>(Tensor<T> gradOutput, Tensor<T> input,
if(kernelShape==null||kernelShape.Length!=4)thrownewArgumentException("kernelShape must be array of 4 elements [outChannels, inChannels, kernelHeight, kernelWidth]",nameof(kernelShape));
if(stride==null||stride.Length!=2)thrownewArgumentException("Stride must be array of 2 elements",nameof(stride));
2530
+
if(stride[0]<=0||stride[1]<=0)thrownewArgumentException("Stride elements must be positive",nameof(stride));
2531
+
if(padding==null||padding.Length!=2)thrownewArgumentException("Padding must be array of 2 elements",nameof(padding));
2532
+
if(dilation==null||dilation.Length!=2)thrownewArgumentException("Dilation must be array of 2 elements",nameof(dilation));
2533
+
if(dilation[0]<=0||dilation[1]<=0)thrownewArgumentException("Dilation elements must be positive",nameof(dilation));
2534
+
if(gradOutput.Shape[0]!=input.Shape[0])thrownewArgumentException($"gradOutput batch size ({gradOutput.Shape[0]}) must match input batch size ({input.Shape[0]})");
2535
+
if(gradOutput.Shape[1]!=kernelShape[0])thrownewArgumentException($"gradOutput outChannels ({gradOutput.Shape[1]}) must match kernelShape outChannels ({kernelShape[0]})");
2536
+
if(input.Shape[1]!=kernelShape[1])thrownewArgumentException($"input inChannels ({input.Shape[1]}) must match kernelShape inChannels ({kernelShape[1]})");
0 commit comments