|
4 | 4 | import java.util.ArrayList; |
5 | 5 | import java.util.Arrays; |
6 | 6 | import java.util.List; |
| 7 | +import java.util.Objects; |
7 | 8 | import java.util.concurrent.atomic.AtomicInteger; |
8 | 9 | import java.util.function.BinaryOperator; |
9 | 10 | import java.util.function.Predicate; |
@@ -265,7 +266,28 @@ public static <T> Object reduceAlongAxis( |
265 | 266 | /** |
266 | 267 | * Private record to hold calculated information about a slice operation. |
267 | 268 | */ |
268 | | - private record SliceInfo(int[] outShape, int[] sliceStarts, int outRank) {} |
| 269 | + private record SliceInfo(int[] outShape, int[] sliceStarts, int outRank) { |
| 270 | + @Override |
| 271 | + public boolean equals(Object o) { |
| 272 | + if (o == null || getClass() != o.getClass()) return false; |
| 273 | + SliceInfo sliceInfo = (SliceInfo) o; |
| 274 | + return outRank == sliceInfo.outRank && Objects.deepEquals(outShape, sliceInfo.outShape) && Objects.deepEquals(sliceStarts, sliceInfo.sliceStarts); |
| 275 | + } |
| 276 | + |
| 277 | + @Override |
| 278 | + public int hashCode() { |
| 279 | + return Objects.hash(Arrays.hashCode(outShape), Arrays.hashCode(sliceStarts), outRank); |
| 280 | + } |
| 281 | + |
| 282 | + @Override |
| 283 | + public String toString() { |
| 284 | + return "SliceInfo{" + |
| 285 | + "outShape=" + Arrays.toString(outShape) + |
| 286 | + ", sliceStarts=" + Arrays.toString(sliceStarts) + |
| 287 | + ", outRank=" + outRank + |
| 288 | + '}'; |
| 289 | + } |
| 290 | + } |
269 | 291 |
|
270 | 292 | /** |
271 | 293 | * Calculates the output shape, starting indices, and rank of a sliced array. |
|
0 commit comments