
Commit f0bccc9

franklinic and claude committed
fix: add flat indexing methods to TensorBase and fix tensor access patterns
Add GetFlat/SetFlat methods to TensorBase for accessing tensor data by linear index, fixing test failures in gradient correctness tests that were using single-integer indexing on multi-dimensional tensors.

- Add GetFlat(int flatIndex) and SetFlat(int flatIndex, T value) to TensorBase
- Fix GumbelSoftmax, TaylorSoftmax, and Pad methods in TensorOperations.cs
- Update GradientCorrectnessTests to use flat indexing helper methods

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent c80e1d3 commit f0bccc9
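For context, a minimal usage sketch of the new helpers (hypothetical shape and values, not taken from this commit; it assumes, as in the hunks below, that `Tensor<T>` accepts an `int[]` shape):

```csharp
// A 2x3 tensor stored row-major: flat index k maps to (row, col) = (k / 3, k % 3).
var t = new Tensor<double>(new[] { 2, 3 });

// The failing pattern was single-integer indexing (t[i]) on a multi-dimensional
// tensor; GetFlat/SetFlat make the linear access explicit and bounds-checked.
for (int k = 0; k < t.Length; k++)
    t.SetFlat(k, k * 10.0);

double v = t.GetFlat(4); // element (1, 1) -> 40.0
```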

File tree

3 files changed: +161 -97 lines changed

src/AiDotNet.Tensors/LinearAlgebra/TensorBase.cs

Lines changed: 34 additions & 0 deletions
```diff
@@ -259,6 +259,40 @@ internal Span<T> AsWritableSpan()
         return _data.AsWritableSpan();
     }
 
+    /// <summary>
+    /// Gets the value at a flat (linear) index in the underlying data.
+    /// </summary>
+    /// <param name="flatIndex">The flat index (0 to Length-1).</param>
+    /// <returns>The value at the specified flat index.</returns>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This allows accessing tensor elements using a single
+    /// index that treats the tensor as a 1D array. The flat index corresponds to
+    /// row-major ordering where the last dimension varies fastest.</para>
+    /// </remarks>
+    public T GetFlat(int flatIndex)
+    {
+        if (flatIndex < 0 || flatIndex >= Length)
+            throw new ArgumentOutOfRangeException(nameof(flatIndex), "Flat index is out of range.");
+        return _data[flatIndex];
+    }
+
+    /// <summary>
+    /// Sets the value at a flat (linear) index in the underlying data.
+    /// </summary>
+    /// <param name="flatIndex">The flat index (0 to Length-1).</param>
+    /// <param name="value">The value to set.</param>
+    /// <remarks>
+    /// <para><b>For Beginners:</b> This allows setting tensor elements using a single
+    /// index that treats the tensor as a 1D array. The flat index corresponds to
+    /// row-major ordering where the last dimension varies fastest.</para>
+    /// </remarks>
+    public void SetFlat(int flatIndex, T value)
+    {
+        if (flatIndex < 0 || flatIndex >= Length)
+            throw new ArgumentOutOfRangeException(nameof(flatIndex), "Flat index is out of range.");
+        _data[flatIndex] = value;
+    }
+
     /// <summary>
     /// Returns a string representation of the tensor.
     /// </summary>
```
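The row-major convention described in the doc comments means the flat index is computed with the last dimension varying fastest. A small sketch of the mapping (the helper name is illustrative, not part of the commit):

```csharp
// Row-major flat index for a rank-3 tensor with dims (d0, d1, d2):
// flat = (i0 * d1 + i1) * d2 + i2 -- the last index varies fastest.
static int Flatten(int i0, int i1, int i2, int d1, int d2)
    => (i0 * d1 + i1) * d2 + i2;

// For shape (2, 3, 4), element (1, 2, 3) lands at (1 * 3 + 2) * 4 + 3 = 23,
// which is the last valid flat index (Length - 1 = 2 * 3 * 4 - 1).
```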

src/Autodiff/TensorOperations.cs

Lines changed: 19 additions & 19 deletions
```diff
@@ -2491,7 +2491,7 @@ public static ComputationNode<T> Pad(ComputationNode<T> a, int[,] padWidth, T? v
         // Initialize with pad value
         for (int i = 0; i < result.Length; i++)
         {
-            result[i] = padValue;
+            result.SetFlat(i, padValue);
         }
         // Copy input data to center
         for (int r = 0; r < inputRows; r++)
@@ -2546,7 +2546,7 @@ void BackwardFunction(Tensor<T> gradient)
         // Initialize with pad value
         for (int i = 0; i < result.Length; i++)
         {
-            result[i] = padValue;
+            result.SetFlat(i, padValue);
         }
 
         // Copy input data to appropriate location
```
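For context, a standalone plain-double sketch of what the surrounding Pad forward pass does, assuming `padWidth[d, 0]` and `padWidth[d, 1]` hold the leading and trailing pad counts per dimension (an interpretation, not confirmed by the excerpt):

```csharp
// Constant-pad a 2-D array: fill with padValue, then copy the input inside.
static double[,] Pad2D(double[,] input, int[,] padWidth, double padValue)
{
    int inRows = input.GetLength(0), inCols = input.GetLength(1);
    int outRows = inRows + padWidth[0, 0] + padWidth[0, 1];
    int outCols = inCols + padWidth[1, 0] + padWidth[1, 1];

    var result = new double[outRows, outCols];
    // Initialize everything with the pad value (the loop the hunks rewrite with SetFlat).
    for (int r = 0; r < outRows; r++)
        for (int c = 0; c < outCols; c++)
            result[r, c] = padValue;

    // Copy the input into the interior region, offset by the leading pads.
    for (int r = 0; r < inRows; r++)
        for (int c = 0; c < inCols; c++)
            result[r + padWidth[0, 0], c + padWidth[1, 0]] = input[r, c];

    return result;
}
```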
```diff
@@ -7515,15 +7515,15 @@ public static ComputationNode<T> GumbelSoftmax(ComputationNode<T> logits, double
             var u = random.NextDouble();
             u = Math.Max(u, eps);
             u = Math.Min(u, 1 - eps);
-            gumbel[i] = numOps.FromDouble(-Math.Log(-Math.Log(u)));
+            gumbel.SetFlat(i, numOps.FromDouble(-Math.Log(-Math.Log(u))));
         }
 
         // Compute soft samples: softmax((logits + gumbel) / temperature)
         var tempTensor = new Tensor<T>(shape);
         for (int i = 0; i < tempTensor.Length; i++)
         {
-            var val = numOps.Add(logits.Value[i], gumbel[i]);
-            tempTensor[i] = numOps.Divide(val, numOps.FromDouble(temperature));
+            var val = numOps.Add(logits.Value.GetFlat(i), gumbel.GetFlat(i));
+            tempTensor.SetFlat(i, numOps.Divide(val, numOps.FromDouble(temperature)));
         }
 
         // Apply softmax along last axis
@@ -7541,18 +7541,18 @@ public static ComputationNode<T> GumbelSoftmax(ComputationNode<T> logits, double
         for (int b = 0; b < batchSize; b++)
         {
             int maxIdx = 0;
-            T maxVal = softResult[b * lastDim];
+            T maxVal = softResult.GetFlat(b * lastDim);
             for (int i = 1; i < lastDim; i++)
             {
-                if (numOps.GreaterThan(softResult[b * lastDim + i], maxVal))
+                if (numOps.GreaterThan(softResult.GetFlat(b * lastDim + i), maxVal))
                 {
-                    maxVal = softResult[b * lastDim + i];
+                    maxVal = softResult.GetFlat(b * lastDim + i);
                     maxIdx = i;
                 }
             }
             for (int i = 0; i < lastDim; i++)
             {
-                hardResult[b * lastDim + i] = i == maxIdx ? numOps.One : numOps.Zero;
+                hardResult.SetFlat(b * lastDim + i, i == maxIdx ? numOps.One : numOps.Zero);
             }
         }
 
```
@@ -9112,7 +9112,7 @@ public static ComputationNode<T> TaylorSoftmax(ComputationNode<T> a, int order =
91129112
for (int i = 0; i < axisSize; i++)
91139113
{
91149114
int flatIdx = outer * axisSize * innerSize + i * innerSize + inner;
9115-
var x = a.Value[flatIdx];
9115+
var x = a.Value.GetFlat(flatIdx);
91169116
var taylorExp = numOps.One; // Start with 1
91179117
var xPower = numOps.One;
91189118

@@ -9128,15 +9128,15 @@ public static ComputationNode<T> TaylorSoftmax(ComputationNode<T> a, int order =
91289128
? taylorExp
91299129
: numOps.FromDouble(1e-10);
91309130

9131-
taylorExpValues[flatIdx] = taylorExp;
9131+
taylorExpValues.SetFlat(flatIdx, taylorExp);
91329132
expSum = numOps.Add(expSum, taylorExp);
91339133
}
91349134

91359135
// Normalize
91369136
for (int i = 0; i < axisSize; i++)
91379137
{
91389138
int flatIdx = outer * axisSize * innerSize + i * innerSize + inner;
9139-
result[flatIdx] = numOps.Divide(taylorExpValues[flatIdx], expSum);
9139+
result.SetFlat(flatIdx, numOps.Divide(taylorExpValues.GetFlat(flatIdx), expSum));
91409140
}
91419141
}
91429142
}
@@ -9163,7 +9163,7 @@ void BackwardFunction(Tensor<T> gradient)
91639163
for (int i = 0; i < capturedAxisSize; i++)
91649164
{
91659165
int flatIdx = outer * capturedAxisSize * capturedInnerSize + i * capturedInnerSize + inner;
9166-
expSum = numOps.Add(expSum, taylorExpValues[flatIdx]);
9166+
expSum = numOps.Add(expSum, taylorExpValues.GetFlat(flatIdx));
91679167
}
91689168

91699169
// Softmax-style Jacobian: s_i * (δ_ij - s_j)
@@ -9172,19 +9172,19 @@ void BackwardFunction(Tensor<T> gradient)
91729172
{
91739173
int flatIdx = outer * capturedAxisSize * capturedInnerSize + i * capturedInnerSize + inner;
91749174
dotProduct = numOps.Add(dotProduct,
9175-
numOps.Multiply(gradient[flatIdx], result[flatIdx]));
9175+
numOps.Multiply(gradient.GetFlat(flatIdx), result.GetFlat(flatIdx)));
91769176
}
91779177

91789178
for (int i = 0; i < capturedAxisSize; i++)
91799179
{
91809180
int flatIdx = outer * capturedAxisSize * capturedInnerSize + i * capturedInnerSize + inner;
91819181
// Softmax gradient part: s_i * (grad_i - dot(grad, s))
9182-
var softmaxGrad = numOps.Multiply(result[flatIdx],
9183-
numOps.Subtract(gradient[flatIdx], dotProduct));
9182+
var softmaxGrad = numOps.Multiply(result.GetFlat(flatIdx),
9183+
numOps.Subtract(gradient.GetFlat(flatIdx), dotProduct));
91849184

91859185
// Taylor exp derivative: d/dx[1 + x + x²/2! + ... + x^n/n!] = 1 + x + ... + x^(n-1)/(n-1)!
91869186
// This is Taylor_{n-1}(x) for exp
9187-
var x = a.Value[flatIdx];
9187+
var x = a.Value.GetFlat(flatIdx);
91889188
var taylorExpDeriv = numOps.One;
91899189
var xPower = numOps.One;
91909190
for (int n = 1; n < capturedOrder; n++)
@@ -9197,9 +9197,9 @@ void BackwardFunction(Tensor<T> gradient)
91979197
// For y_i = g(x_i) / sum_j(g(x_j)), the chain rule requires:
91989198
// grad_x_i = softmaxGrad * g'(x_i) / g(x_i)
91999199
// where g is the Taylor approximation of exp
9200-
var gVal = taylorExpValues[flatIdx];
9200+
var gVal = taylorExpValues.GetFlat(flatIdx);
92019201
var gPrimeOverG = numOps.Divide(taylorExpDeriv, gVal);
9202-
gradA[flatIdx] = numOps.Multiply(softmaxGrad, gPrimeOverG);
9202+
gradA.SetFlat(flatIdx, numOps.Multiply(softmaxGrad, gPrimeOverG));
92039203
}
92049204
}
92059205
}
