diff --git a/src/JitCompiler/IR/IRGraph.cs b/src/JitCompiler/IR/IRGraph.cs
new file mode 100644
index 000000000..76e4a6892
--- /dev/null
+++ b/src/JitCompiler/IR/IRGraph.cs
@@ -0,0 +1,265 @@
+namespace AiDotNet.JitCompiler.IR;
+
+/// 
+/// Represents a computation graph in intermediate representation form.
+/// 
+/// 
+/// 
+/// An IRGraph is a structured representation of a sequence of tensor operations
+/// that have been recorded during autodiff execution. It serves as an intermediate
+/// format between the high-level ComputationNode graph and the low-level compiled code.
+/// 
+/// For Beginners: Think of an IRGraph as a recipe for computations.
+/// 
+/// Just like a recipe lists ingredients and steps:
+/// - InputIds are the ingredients (input tensors)
+/// - Operations are the cooking steps (add, multiply, etc.)
+/// - OutputIds are the final dishes (output tensors)
+/// - TensorShapes tells us the "size" of each intermediate result
+/// 
+/// The IR graph makes it easier to optimize the computation (like combining steps)
+/// and then compile it to fast executable code.
+/// 
+/// Example:
+/// If your model does: result = ReLU(MatMul(input, weights) + bias)
+/// The IR graph would have 3 operations: MatMul, Add, ReLU
+/// Each operation knows its inputs and produces an output.
+/// 
+/// 
+public class IRGraph
+{
+    /// 
+    /// Gets or sets the list of operations in this graph, in execution order.
+    /// 
+    /// 
+    /// 
+    /// Operations are stored in topological order, meaning each operation appears
+    /// after all operations that produce its inputs. This ensures correct execution order.
+    /// 
+    /// For Beginners: This is the ordered list of computation steps.
+    /// 
+    /// The order matters! You can't add two numbers before you've computed them.
+    /// Each operation in the list uses results from earlier operations.
+    /// 
+    /// 
+    public List<IROp> Operations { get; set; } = new();
+
+    /// 
+    /// Gets or sets the mapping from tensor IDs to their shapes.
+    /// 
+    /// 
+    /// 
+    /// Every tensor in the graph (inputs, outputs, and intermediates) has a unique ID
+    /// and a known shape (represented as int[] matching Tensor<T>.Shape).
+    /// This dictionary provides that mapping.
+    /// 
+    /// For Beginners: This is like a table that tells us the size of each value.
+    /// 
+    /// For example:
+    /// - Tensor 0 might be [32, 784] (a batch of 32 images, each with 784 pixels)
+    /// - Tensor 1 might be [784, 128] (weights connecting 784 inputs to 128 outputs)
+    /// - Tensor 2 might be [32, 128] (the result of multiplying tensor 0 and 1)
+    /// 
+    /// Knowing shapes helps us:
+    /// - Allocate the right amount of memory
+    /// - Check that operations are valid (can't multiply incompatible shapes)
+    /// - Optimize operations for specific sizes
+    /// 
+    /// 
+    public Dictionary<int, int[]> TensorShapes { get; set; } = new();
+
+    /// 
+    /// Gets or sets the IDs of input tensors to this graph.
+    /// 
+    /// 
+    /// 
+    /// Input tensors are provided by the caller and are not computed within the graph.
+    /// They serve as the starting point for all computations.
+    /// 
+    /// For Beginners: These are the "ingredients" that you provide to start the computation.
+    /// 
+    /// For a neural network, inputs might be:
+    /// - The input data (like an image)
+    /// - Model parameters (weights and biases)
+    /// 
+    /// The graph will process these inputs through all its operations to produce outputs.
+    /// 
+    /// 
+    public List<int> InputIds { get; set; } = new();
+
+    /// 
+    /// Gets or sets the IDs of output tensors produced by this graph.
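+    /// A minimal construction sketch (illustrative only; MatMulOp is assumed to be the
+    /// matrix-multiply IROp subclass defined elsewhere in this PR):
+    /// <code>
+    /// var graph = new IRGraph();
+    /// graph.InputIds.AddRange(new[] { 0, 1 });             // t0 = input, t1 = weights
+    /// graph.TensorShapes[0] = new[] { 32, 784 };
+    /// graph.TensorShapes[1] = new[] { 784, 128 };
+    /// graph.Operations.Add(new MatMulOp
+    /// {
+    ///     OutputId = 2, InputIds = new[] { 0, 1 },
+    ///     OutputType = IRType.Float32, OutputShape = new[] { 32, 128 }
+    /// });
+    /// graph.OutputIds.Add(2);                              // t2 is the final result
+    /// </code>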
+ /// + /// + /// + /// Output tensors are the final results of the graph computation and are + /// returned to the caller. + /// + /// For Beginners: These are the "final dishes" - the results you care about. + /// + /// For a neural network, outputs might be: + /// - Predictions (class probabilities) + /// - Loss value + /// - Intermediate features (for visualization) + /// + /// Everything else in the graph is just intermediate calculations to get to these outputs. + /// + /// + public List OutputIds { get; set; } = new(); + + /// + /// Gets or sets optional metadata about the graph. + /// + public Dictionary Metadata { get; set; } = new(); + + /// + /// Validates the graph structure for correctness. + /// + /// True if the graph is valid, false otherwise. + /// + /// + /// Validation checks include: + /// - All input tensor IDs are defined in TensorShapes + /// - All operation inputs reference valid tensor IDs + /// - No cycles in the graph (it's a DAG) + /// - All output IDs are produced by operations or are inputs + /// + /// For Beginners: This checks that the "recipe" makes sense. + /// + /// It verifies: + /// - You're not using an ingredient that doesn't exist + /// - Steps are in the right order (don't use results before computing them) + /// - The final outputs are actually produced by the recipe + /// + /// If validation fails, something is wrong with how the graph was constructed. + /// + /// + public bool Validate() + { + // Check that all inputs have shapes defined + foreach (var inputId in InputIds) + { + if (!TensorShapes.ContainsKey(inputId)) + { + return false; + } + } + + // Track which tensors have been produced + var producedTensors = new HashSet(InputIds); + + // Check each operation + foreach (var op in Operations) + { + // Validate the operation itself + if (!op.Validate()) + { + return false; + } + + // Check that all inputs have been produced + foreach (var inputId in op.InputIds) + { + if (!producedTensors.Contains(inputId)) + { + return false; // Using a tensor before it's produced + } + } + + // Mark output as produced + producedTensors.Add(op.OutputId); + + // Ensure output shape is defined + if (!TensorShapes.ContainsKey(op.OutputId)) + { + TensorShapes[op.OutputId] = op.OutputShape; + } + } + + // Check that all outputs have been produced + foreach (var outputId in OutputIds) + { + if (!producedTensors.Contains(outputId)) + { + return false; + } + } + + return true; + } + + /// + /// Gets a string representation of the graph for debugging and visualization. + /// + public override string ToString() + { + var sb = new System.Text.StringBuilder(); + sb.AppendLine($"IR Graph:"); + sb.AppendLine($" Inputs: {string.Join(", ", InputIds.Select(id => $"t{id}"))}"); + sb.AppendLine($" Operations ({Operations.Count}):"); + foreach (var op in Operations) + { + sb.AppendLine($" {op}"); + } + sb.AppendLine($" Outputs: {string.Join(", ", OutputIds.Select(id => $"t{id}"))}"); + return sb.ToString(); + } + + /// + /// Computes a hash code for this graph structure (ignoring tensor values). + /// + /// + /// + /// The hash is based on the graph structure: operation types, shapes, and connectivity. + /// This is used for caching compiled graphs - graphs with the same structure can reuse + /// the same compiled code even if the actual tensor values are different. + /// + /// For Beginners: This creates a "fingerprint" for the graph structure. 
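+    /// A cache sketch built on this hash (hypothetical - the compile step and the
+    /// delegate type are assumptions, not part of this file; note that hash collisions
+    /// are possible, so a production cache would also compare structures):
+    /// <code>
+    /// // given an IRGraph "graph":
+    /// var cache = new Dictionary<int, Func<Tensor<float>[], Tensor<float>[]>>();
+    /// int key = graph.ComputeStructureHash();
+    /// if (!cache.TryGetValue(key, out var compiled))
+    /// {
+    ///     compiled = CompileGraph(graph);   // hypothetical compile step
+    ///     cache[key] = compiled;
+    /// }
+    /// </code>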
+ /// + /// Two graphs with the same fingerprint have the same structure (same operations, + /// same shapes) even if the actual numbers in the tensors are different. + /// + /// This lets us reuse compiled code: + /// - First time: Compile the graph (slow) + /// - Next time with same structure: Reuse compiled code (fast!) + /// + /// It's like having a pre-cooked recipe that you can use with different ingredients. + /// + /// + public int ComputeStructureHash() + { + int hash = 17; + + // Hash input shapes + foreach (var inputId in InputIds.OrderBy(id => id)) + { + hash = hash * 31 + inputId.GetHashCode(); + if (TensorShapes.TryGetValue(inputId, out var shape)) + { + hash = hash * 31 + shape.GetShapeHashCode(); + } + } + + // Hash operations + foreach (var op in Operations) + { + hash = hash * 31 + op.OpType.GetHashCode(); + hash = hash * 31 + op.OutputId.GetHashCode(); + hash = hash * 31 + op.OutputType.GetHashCode(); + hash = hash * 31 + op.OutputShape.GetShapeHashCode(); + + foreach (var inputId in op.InputIds) + { + hash = hash * 31 + inputId.GetHashCode(); + } + } + + // Hash output IDs + foreach (var outputId in OutputIds.OrderBy(id => id)) + { + hash = hash * 31 + outputId.GetHashCode(); + } + + return hash; + } +} diff --git a/src/JitCompiler/IR/IROp.cs b/src/JitCompiler/IR/IROp.cs new file mode 100644 index 000000000..ec75fdd61 --- /dev/null +++ b/src/JitCompiler/IR/IROp.cs @@ -0,0 +1,280 @@ +namespace AiDotNet.JitCompiler.IR; + +/// +/// Base class for all IR operations. +/// +/// +/// +/// IROp represents a single operation in the intermediate representation graph. +/// Each operation has inputs (tensor IDs), produces an output (tensor ID), and +/// has metadata about types and shapes. +/// +/// For Beginners: An IROp is like a single step in a recipe. +/// +/// Each operation: +/// - Takes some inputs (the tensor IDs it needs) +/// - Performs a calculation (add, multiply, etc.) +/// - Produces an output (a new tensor ID) +/// - Knows what type and shape the output will be +/// +/// For example, an "Add" operation might: +/// - Take inputs: tensor 0 and tensor 1 +/// - Perform: element-wise addition +/// - Produce: tensor 2 +/// - Know: output has the same shape as the inputs +/// +/// The JIT compiler uses this information to generate optimized code. +/// +/// +public abstract class IROp +{ + /// + /// Gets or sets the unique identifier for the output of this operation. + /// + /// + /// + /// The output ID identifies the tensor produced by this operation. + /// It's used by subsequent operations to reference this result. + /// + /// For Beginners: This is like a variable name for the result. + /// + /// For example, if this operation computes "c = a + b": + /// - OutputId might be 2 (representing "c") + /// - InputIds might be [0, 1] (representing "a" and "b") + /// + /// Later operations can use tensor 2 as their input. + /// + /// + public int OutputId { get; set; } + + /// + /// Gets or sets the identifiers of the input tensors to this operation. + /// + /// + /// + /// Input IDs reference tensors that must be computed before this operation. + /// They can be graph inputs, constants, or outputs from earlier operations. + /// + /// For Beginners: These are the inputs this operation needs. + /// + /// For a binary operation like addition: + /// - InputIds = [0, 1] means "add tensor 0 and tensor 1" + /// + /// For a unary operation like ReLU: + /// - InputIds = [5] means "apply ReLU to tensor 5" + /// + /// The order matters! 
For subtraction, [0, 1] means "0 - 1", not "1 - 0". + /// + /// + public int[] InputIds { get; set; } = Array.Empty(); + + /// + /// Gets or sets the data type of the output tensor. + /// + /// + /// + /// The output type determines what numeric type (float, double, int, etc.) + /// the result tensor will use. This affects memory usage and precision. + /// + /// For Beginners: This tells us what kind of numbers the result contains. + /// + /// Common types: + /// - Float32: Single-precision floating point (most common for neural networks) + /// - Float64: Double-precision floating point (higher precision, more memory) + /// - Int32: 32-bit integers + /// + /// The type affects: + /// - Memory usage (float32 uses half the memory of float64) + /// - Precision (how accurate calculations are) + /// - Performance (some operations are faster with certain types) + /// + /// + public IRType OutputType { get; set; } + + /// + /// Gets or sets the shape of the output tensor. + /// + /// + /// + /// The output shape is represented as an int[] array matching the existing + /// Tensor<T>.Shape format. Each element is the size of that dimension. + /// + /// For Beginners: This tells us the size and dimensions of the result. + /// + /// Examples: + /// - [] = scalar (single number) + /// - [10] = vector with 10 elements + /// - [3, 4] = 3×4 matrix + /// - [32, 3, 224, 224] = batch of 32 RGB images, each 224×224 pixels + /// + /// The shape is determined by the operation: + /// - Adding [3, 4] + [3, 4] → [3, 4] (same shape) + /// - Matrix multiply [3, 4] × [4, 5] → [3, 5] (rows from left, cols from right) + /// - Sum [3, 4] along axis 1 → [3] (reduces one dimension) + /// + /// + public int[] OutputShape { get; set; } = Array.Empty(); + + /// + /// Gets the operation type name for debugging and visualization. + /// + /// + /// + /// By default, this returns the class name without the "Op" suffix. + /// For example, "MatMulOp" becomes "MatMul". + /// + /// For Beginners: This is a human-readable name for the operation. + /// + /// Used for: + /// - Debugging (see what operations are in the graph) + /// - Visualization (draw a graph diagram) + /// - Logging (track what the compiler is doing) + /// + /// Examples: "Add", "MatMul", "ReLU", "Conv2D" + /// + /// + public virtual string OpType => GetType().Name.Replace("Op", ""); + + /// + /// Validates that this operation is correctly formed. + /// + /// True if valid, false otherwise. + /// + /// + /// Basic validation checks that the operation has required information. + /// Derived classes can override to add operation-specific validation. + /// + /// For Beginners: This checks that the operation makes sense. + /// + /// Basic checks: + /// - Output ID is valid (non-negative) + /// - Has the right number of inputs + /// - Shapes are compatible + /// + /// Specific operations add their own checks: + /// - MatMul: inner dimensions must match + /// - Conv2D: kernel size must be valid + /// - Reshape: total elements must be preserved + /// + /// If validation fails, the operation can't be compiled. + /// + /// + public virtual bool Validate() + { + // Basic validation: output ID should be non-negative + if (OutputId < 0) + return false; + + // Output shape should be valid + if (OutputShape == null || !OutputShape.IsValidShape()) + return false; + + return true; + } + + /// + /// Gets a string representation of this operation for debugging. + /// + /// A string describing this operation. 
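+    /// For example (a sketch using the ReLUOp defined later in this PR):
+    /// <code>
+    /// var op = new ReLUOp { OutputId = 8, InputIds = new[] { 7 },
+    ///     OutputType = IRType.Float32, OutputShape = new[] { 32, 128 } };
+    /// Console.WriteLine(op);   // prints: t8 = ReLU(t7) : Float32 [32, 128]
+    /// </code>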
+ /// + /// + /// The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]" + /// + /// For Beginners: This creates a readable description of the operation. + /// + /// Example outputs: + /// - "t2 = Add(t0, t1) : Float32 [3, 4]" + /// - "t5 = MatMul(t3, t4) : Float32 [128, 256]" + /// - "t8 = ReLU(t7) : Float32 [32, 128]" + /// + /// This is super helpful for debugging - you can see exactly what each + /// operation does and what shape tensors flow through the graph. + /// + /// + public override string ToString() + { + var inputs = string.Join(", ", InputIds.Select(id => $"t{id}")); + return $"t{OutputId} = {OpType}({inputs}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Interface for optimization passes that transform IR graphs. +/// +/// +/// +/// Optimization passes take an IR graph and transform it to an equivalent +/// but more efficient version. Examples include constant folding, dead code +/// elimination, and operation fusion. +/// +/// For Beginners: An optimization pass improves the graph without changing what it computes. +/// +/// Think of it like optimizing a recipe: +/// - Original: "Add 1 cup flour. Add another 1 cup flour." +/// - Optimized: "Add 2 cups flour." +/// - Result is the same, but simpler! +/// +/// Common optimizations: +/// - Constant folding: Compute constant expressions at compile time +/// - Dead code elimination: Remove operations whose results aren't used +/// - Operation fusion: Combine multiple operations into one +/// - Common subexpression elimination: Compute repeated expressions only once +/// +/// These make the compiled code faster by: +/// - Doing less work +/// - Using less memory +/// - Better utilizing CPU/GPU resources +/// +/// +public interface IOptimizationPass +{ + /// + /// Applies this optimization pass to an IR graph. + /// + /// The graph to optimize. + /// The optimized graph (may be the same instance or a new one). + /// + /// + /// The optimization must preserve the semantics of the graph - it should + /// produce the same results for the same inputs, just more efficiently. + /// + /// For Beginners: This method transforms the graph to make it faster. + /// + /// The pass: + /// - Examines the graph to find optimization opportunities + /// - Creates a new, more efficient version + /// - Returns the optimized graph + /// + /// The optimized graph computes the same results but runs faster. + /// + /// Multiple passes can be chained: + /// - Original graph + /// - → Constant folding + /// - → Dead code elimination + /// - → Operation fusion + /// - → Optimized graph (much faster!) + /// + /// + IRGraph Optimize(IRGraph graph); + + /// + /// Gets the name of this optimization pass. + /// + /// + /// + /// The name is used for logging and debugging to track which optimizations + /// have been applied to a graph. + /// + /// For Beginners: A human-readable name for this optimization. + /// + /// Examples: + /// - "Constant Folding" + /// - "Dead Code Elimination" + /// - "Operation Fusion" + /// + /// Used when printing optimization logs like: + /// "Applied Constant Folding: reduced 150 ops to 142 ops" + /// + /// + string Name { get; } +} diff --git a/src/JitCompiler/IR/IRType.cs b/src/JitCompiler/IR/IRType.cs new file mode 100644 index 000000000..311963a63 --- /dev/null +++ b/src/JitCompiler/IR/IRType.cs @@ -0,0 +1,71 @@ +namespace AiDotNet.JitCompiler.IR; + +/// +/// Represents the data type of a tensor in the IR. 
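+/// Conversion helpers are provided below; the mapping is round-trippable for the
+/// supported types (note that Half and Complex are declared in the enum but have no
+/// System.Type mapping yet, so converting them currently throws NotSupportedException):
+/// <code>
+/// var t = IRTypeExtensions.FromSystemType(typeof(float));   // IRType.Float32
+/// Type back = t.ToSystemType();                             // typeof(float)
+/// </code>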
+/// +public enum IRType +{ + Float32, + Float64, + Int32, + Int64, + Byte, + SByte, + Int16, + UInt16, + UInt32, + UInt64, + Decimal, + Half, + Complex +} + +/// +/// Helper methods for IRType. +/// +public static class IRTypeExtensions +{ + /// + /// Gets the IRType for a given System.Type. + /// + public static IRType FromSystemType(Type type) + { + return type switch + { + Type t when t == typeof(float) => IRType.Float32, + Type t when t == typeof(double) => IRType.Float64, + Type t when t == typeof(int) => IRType.Int32, + Type t when t == typeof(long) => IRType.Int64, + Type t when t == typeof(byte) => IRType.Byte, + Type t when t == typeof(sbyte) => IRType.SByte, + Type t when t == typeof(short) => IRType.Int16, + Type t when t == typeof(ushort) => IRType.UInt16, + Type t when t == typeof(uint) => IRType.UInt32, + Type t when t == typeof(ulong) => IRType.UInt64, + Type t when t == typeof(decimal) => IRType.Decimal, + _ => throw new NotSupportedException($"Type {type} not supported in IR") + }; + } + + /// + /// Gets the System.Type for a given IRType. + /// + public static Type ToSystemType(this IRType irType) + { + return irType switch + { + IRType.Float32 => typeof(float), + IRType.Float64 => typeof(double), + IRType.Int32 => typeof(int), + IRType.Int64 => typeof(long), + IRType.Byte => typeof(byte), + IRType.SByte => typeof(sbyte), + IRType.Int16 => typeof(short), + IRType.UInt16 => typeof(ushort), + IRType.UInt32 => typeof(uint), + IRType.UInt64 => typeof(ulong), + IRType.Decimal => typeof(decimal), + _ => throw new NotSupportedException($"IRType {irType} conversion not supported") + }; + } +} diff --git a/src/JitCompiler/IR/Operations/ActivationOps.cs b/src/JitCompiler/IR/Operations/ActivationOps.cs new file mode 100644 index 000000000..99164fcac --- /dev/null +++ b/src/JitCompiler/IR/Operations/ActivationOps.cs @@ -0,0 +1,731 @@ +namespace AiDotNet.JitCompiler.IR.Operations; + +/// +/// Represents ReLU (Rectified Linear Unit) activation in the IR. +/// +/// +/// +/// Corresponds to TensorOperations.ReLU(). +/// Computes max(0, x) for each element: result[i] = max(0, a[i]). +/// +/// For Beginners: Keeps positive values, zeros out negative values. +/// +/// Example: +/// ReLU([-2, -1, 0, 1, 2]) = [0, 0, 0, 1, 2] +/// +/// Very common in neural networks because it's simple and effective. +/// +/// +public class ReLUOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents Sigmoid activation in the IR. +/// +/// +/// +/// Corresponds to TensorOperations.Sigmoid(). +/// Computes sigmoid function: result[i] = 1 / (1 + exp(-a[i])). +/// Output range is (0, 1). +/// +/// For Beginners: Squashes values to between 0 and 1. +/// +/// Example: +/// Sigmoid([-∞, -2, 0, 2, ∞]) ≈ [0, 0.12, 0.5, 0.88, 1] +/// +/// Used for binary classification (outputs can be interpreted as probabilities). +/// +/// +public class SigmoidOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents Tanh (hyperbolic tangent) activation in the IR. +/// +/// +/// +/// Corresponds to TensorOperations.Tanh(). +/// Computes tanh function: result[i] = (exp(a[i]) - exp(-a[i])) / (exp(a[i]) + exp(-a[i])). +/// Output range is (-1, 1). +/// +/// For Beginners: Squashes values to between -1 and 1. 
+/// +/// Example: +/// Tanh([-∞, -2, 0, 2, ∞]) ≈ [-1, -0.96, 0, 0.96, 1] +/// +/// Similar to sigmoid but centered at zero, often works better than sigmoid. +/// +/// +public class TanhOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents Softmax activation in the IR. +/// +/// +/// +/// Corresponds to TensorOperations.Softmax(). +/// Computes softmax along specified axis. Converts logits to probabilities. +/// +/// For Beginners: Converts scores to probabilities that sum to 1. +/// +/// Example: +/// Softmax([1, 2, 3]) ≈ [0.09, 0.24, 0.67] +/// (notice they sum to 1.0) +/// +/// Used for multi-class classification - outputs can be interpreted as +/// class probabilities. +/// +/// +public class SoftmaxOp : IROp +{ + /// + /// The axis along which to compute softmax. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = Softmax(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents a generic activation function application in the IR. +/// +/// +/// +/// Corresponds to TensorOperations.ApplyActivation(). +/// Applies a named activation function to the input. +/// +/// For Beginners: Applies any activation function by name. +/// +/// This is a more generic operation that can apply various activations +/// (ReLU, Sigmoid, Tanh, etc.) based on a parameter rather than being +/// hard-coded to one specific activation. +/// +/// +public class ApplyActivationOp : IROp +{ + /// + /// The name of the activation function to apply. + /// + public string ActivationName { get; set; } = string.Empty; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + if (string.IsNullOrWhiteSpace(ActivationName)) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = ApplyActivation(t{InputIds[0]}, \"{ActivationName}\") : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Softmin activation in the IR (min-based variant of softmax). +/// +/// +/// +/// Computes softmin along specified axis: softmin(x) = softmax(-x). +/// Converts negative logits to probabilities that sum to 1. +/// +/// For Beginners: Like softmax, but emphasizes smaller values. +/// +/// Example: +/// Softmin([1, 2, 3]) approximately equals [0.67, 0.24, 0.09] +/// (notice the smallest value gets the highest probability) +/// +/// Less common than softmax, but useful when minimizing is desired. +/// +/// +public class SoftminOp : IROp +{ + /// + /// The axis along which to compute softmin. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = Softmin(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents LogSoftmax activation in the IR (numerically stable). +/// +/// +/// +/// Computes log(softmax(x)) using log-sum-exp trick for numerical stability. +/// Equivalent to log(softmax(x)) but avoids overflow/underflow. 
+/// +/// For Beginners: Logarithm of softmax probabilities. +/// +/// Example: +/// LogSoftmax([1, 2, 3]) approximately equals [-2.41, -1.41, -0.41] +/// +/// More numerically stable than computing log(softmax(x)) separately. +/// Often used with negative log-likelihood loss in classification. +/// +/// +public class LogSoftmaxOp : IROp +{ + /// + /// The axis along which to compute log-softmax. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = LogSoftmax(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents LogSoftmin activation in the IR (numerically stable). +/// +/// +/// +/// Computes log(softmin(x)) using log-sum-exp trick for numerical stability. +/// +/// For Beginners: Logarithm of softmin probabilities. +/// +/// Example: +/// LogSoftmin([1, 2, 3]) approximately equals [-0.41, -1.41, -2.41] +/// +/// Numerically stable version of log(softmin(x)). +/// +/// +public class LogSoftminOp : IROp +{ + /// + /// The axis along which to compute log-softmin. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = LogSoftmin(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Sparsemax activation in the IR (sparse alternative to softmax). +/// +/// +/// +/// Computes sparsemax projection: produces sparse probability distributions. +/// Unlike softmax, can produce exact zeros for low-probability classes. +/// +/// For Beginners: Like softmax, but can produce exact zeros. +/// +/// Example: +/// Sparsemax([1, 2, 7]) approximately equals [0, 0, 1] +/// (notice exact zeros for unlikely classes) +/// +/// Useful when you want sparse predictions (most classes with zero probability). +/// +/// TODO: Implement efficient sparsemax algorithm. +/// Current implementation is placeholder - requires O(n log n) projection algorithm. +/// +/// +public class SparsemaxOp : IROp +{ + /// + /// The axis along which to compute sparsemax. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = Sparsemax(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Spherical Softmax activation in the IR (softmax on unit sphere). +/// +/// +/// +/// Computes softmax after normalizing input vectors to unit sphere. +/// Useful for angular-based representations. +/// +/// For Beginners: Softmax applied to normalized vectors. +/// +/// First normalizes each vector to unit length, then applies softmax. +/// Useful when direction matters more than magnitude. +/// +/// +public class SphericalSoftmaxOp : IROp +{ + /// + /// The axis along which to compute spherical softmax. Default is -1 (last axis). 
+ /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = SphericalSoftmax(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Gumbel-Softmax activation in the IR (stochastic, differentiable). +/// +/// +/// +/// Computes Gumbel-Softmax: softmax((x + Gumbel noise) / temperature). +/// Provides differentiable sampling from categorical distributions. +/// +/// For Beginners: Softmax with controllable randomness. +/// +/// Adds Gumbel noise before softmax to enable stochastic discrete choices +/// while maintaining differentiability. Temperature controls randomness. +/// +/// Used in variational autoencoders and discrete latent variable models. +/// +/// +public class GumbelSoftmaxOp : IROp +{ + /// + /// Temperature parameter controlling randomness. Lower = more deterministic. + /// + public double Temperature { get; set; } = 1.0; + + /// + /// The axis along which to compute Gumbel-Softmax. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + if (Temperature <= 0) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = GumbelSoftmax(t{InputIds[0]}, temp={Temperature}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Taylor-Softmax activation in the IR (Taylor series approximation). +/// +/// +/// +/// Approximates softmax using Taylor series expansion. +/// Faster but less accurate than standard softmax. +/// +/// For Beginners: Fast approximation of softmax. +/// +/// Uses polynomial approximation instead of expensive exponentials. +/// Trades accuracy for speed - good for low-precision applications. +/// +/// TODO: Implement Taylor series approximation. +/// Current implementation is placeholder - requires order parameter for series. +/// +/// +public class TaylorSoftmaxOp : IROp +{ + /// + /// Order of Taylor series approximation. Higher = more accurate, slower. + /// + public int Order { get; set; } = 2; + + /// + /// The axis along which to compute Taylor-Softmax. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + if (Order < 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = TaylorSoftmax(t{InputIds[0]}, order={Order}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Hierarchical Softmax activation in the IR (tree-structured). +/// +/// +/// +/// Computes hierarchical softmax using binary tree structure. +/// Reduces computational complexity from O(n) to O(log n). +/// +/// For Beginners: Efficient softmax for many classes. +/// +/// Instead of computing probabilities for all classes at once, +/// makes binary decisions in a tree structure. +/// +/// Much faster when number of classes is very large (e.g., vocabulary in NLP). +/// +/// TODO: Implement hierarchical tree structure. +/// Current implementation is placeholder - requires tree specification. +/// +/// +public class HierarchicalSoftmaxOp : IROp +{ + /// + /// Tree structure specification (placeholder - needs design). 
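+    /// One candidate encoding (an assumption, not a committed design): a serialized
+    /// binary Huffman tree over class frequencies, as used for hierarchical softmax
+    /// in word2vec.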
+ /// + public string TreeStructure { get; set; } = string.Empty; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = HierarchicalSoftmax(t{InputIds[0]}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Maxout activation in the IR. +/// +/// +/// +/// Computes max(W1*x + b1, W2*x + b2, ...) across multiple linear projections. +/// Learns the activation function itself through multiple weight sets. +/// +/// For Beginners: Takes maximum across multiple linear transformations. +/// +/// Instead of applying a fixed function like ReLU, computes several +/// linear functions and takes the max. The network learns which function +/// shape works best. +/// +/// More powerful but requires more parameters than standard activations. +/// +/// +public class MaxoutOp : IROp +{ + /// + /// Number of linear projections to max over. + /// + public int NumProjections { get; set; } = 2; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length < 1) return false; + if (NumProjections < 2) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = Maxout(t{InputIds[0]}, projections={NumProjections}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Sign activation in the IR. +/// +/// +/// +/// Computes sign function: -1 for negative, 0 for zero, +1 for positive. +/// +/// For Beginners: Outputs only -1, 0, or +1. +/// +/// Example: +/// Sign([-5.3, -0.1, 0, 0.1, 5.3]) = [-1, -1, 0, 1, 1] +/// +/// Used in binary neural networks and sign-based optimization. +/// Not differentiable at zero, so requires special gradient handling. +/// +/// +public class SignOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents Gaussian activation in the IR. +/// +/// +/// +/// Computes Gaussian function: exp(-x^2). +/// Bell-shaped curve centered at zero. +/// +/// For Beginners: Bell curve activation. +/// +/// Example: +/// Gaussian([-2, -1, 0, 1, 2]) approximately equals [0.02, 0.37, 1.0, 0.37, 0.02] +/// +/// Maximum at zero, decreases towards zero as x moves away from origin. +/// Used in radial basis function networks. +/// +/// +public class GaussianOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents ISRU (Inverse Square Root Unit) activation in the IR. +/// +/// +/// +/// Computes ISRU: x / sqrt(1 + alpha * x^2). +/// Self-normalizing activation similar to ELU but faster. +/// +/// For Beginners: Smooth, bounded activation function. +/// +/// Example (alpha=1): +/// ISRU([-2, -1, 0, 1, 2]) approximately equals [-0.89, -0.71, 0, 0.71, 0.89] +/// +/// Output range is approximately (-1/sqrt(alpha), 1/sqrt(alpha)). +/// Faster than ELU because it avoids exponentials. +/// +/// +public class ISRUOp : IROp +{ + /// + /// Alpha parameter controlling the curve shape. Default is 1.0. 
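+    /// Since ISRU(x) = x / sqrt(1 + alpha * x^2), larger alpha tightens the output
+    /// bound of ±1/sqrt(alpha); e.g. alpha = 4 bounds outputs to ±0.5.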
+ /// + public double Alpha { get; set; } = 1.0; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + if (Alpha <= 0) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = ISRU(t{InputIds[0]}, alpha={Alpha}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents LiSHT (Linearly Scaled Hyperbolic Tangent) activation in the IR. +/// +/// +/// +/// Computes LiSHT: x * tanh(x). +/// Combines linear and tanh properties. +/// +/// For Beginners: Smooth, non-monotonic activation. +/// +/// Example: +/// LiSHT([-2, -1, 0, 1, 2]) approximately equals [-1.93, -0.76, 0, 0.76, 1.93] +/// +/// Similar to Swish but uses tanh instead of sigmoid. +/// Has a small negative region and grows almost linearly for large x. +/// +/// +public class LiSHTOp : IROp +{ + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } +} + +/// +/// Represents SQRBF (Squared Radial Basis Function) activation in the IR. +/// +/// +/// +/// Computes squared RBF: exp(-beta * x^2). +/// Gaussian-like activation with adjustable width. +/// +/// For Beginners: Adjustable bell curve. +/// +/// Example (beta=1): +/// SQRBF([-2, -1, 0, 1, 2]) approximately equals [0.02, 0.37, 1.0, 0.37, 0.02] +/// +/// Beta controls the width of the bell curve. +/// Used in radial basis function networks for local learning. +/// +/// +public class SQRBFOp : IROp +{ + /// + /// Beta parameter controlling the RBF width. Default is 1.0. + /// + public double Beta { get; set; } = 1.0; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + if (Beta <= 0) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = SQRBF(t{InputIds[0]}, beta={Beta}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Squash activation in the IR (capsule network squashing). +/// +/// +/// +/// Computes squashing function: (||x||^2 / (1 + ||x||^2)) * (x / ||x||). +/// Squashes vector length to [0, 1) while preserving direction. +/// +/// For Beginners: Normalizes vector length to less than 1. +/// +/// Used in capsule networks to represent presence of features. +/// - Long vectors stay long (approach length 1) +/// - Short vectors get shorter (approach length 0) +/// - Direction is always preserved +/// +/// Unlike softmax, works on vector magnitudes, not individual elements. +/// +/// +public class SquashOp : IROp +{ + /// + /// The axis along which to compute vector norms. Default is -1 (last axis). + /// + public int Axis { get; set; } = -1; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = Squash(t{InputIds[0]}, axis={Axis}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} + +/// +/// Represents Binary Spiking Activation in the IR (for spiking neural networks). +/// +/// +/// +/// Computes binary step function with threshold: output = (x >= threshold) ? 1 : 0. +/// Used in spiking neural networks to model neuron firing. +/// +/// For Beginners: Outputs 1 if above threshold, 0 otherwise. +/// +/// Example (threshold=0.5): +/// BinarySpike([0.1, 0.5, 0.9, 1.5]) = [0, 1, 1, 1] +/// +/// Models biological neurons that fire when membrane potential exceeds threshold. 
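+/// In other words, it is the Heaviside step H(x - threshold).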
+/// Not differentiable, requires surrogate gradients for training. +/// +/// +public class BinarySpikingActivationOp : IROp +{ + /// + /// Firing threshold. Default is 0.5. + /// + public double Threshold { get; set; } = 0.5; + + public override bool Validate() + { + if (!base.Validate()) return false; + if (InputIds.Length != 1) return false; + return true; + } + + public override string ToString() + { + return $"t{OutputId} = BinarySpike(t{InputIds[0]}, threshold={Threshold}) : {OutputType} {OutputShape.ShapeToString()}"; + } +} diff --git a/src/JitCompiler/IR/TensorShape.cs b/src/JitCompiler/IR/TensorShape.cs new file mode 100644 index 000000000..8e6ea8ca3 --- /dev/null +++ b/src/JitCompiler/IR/TensorShape.cs @@ -0,0 +1,313 @@ +using AiDotNet.LinearAlgebra; + +namespace AiDotNet.JitCompiler.IR; + +/// +/// Provides extension methods and utilities for working with tensor shapes in the IR. +/// +/// +/// +/// This class provides helper methods for working with tensor shapes (represented as int[] arrays). +/// It integrates with the existing Tensor<T> infrastructure which already uses int[] for shapes. +/// +/// For Beginners: In AiDotNet, tensor shapes are represented as integer arrays. +/// +/// For example: +/// - [5] is a vector with 5 elements +/// - [3, 4] is a 3×4 matrix +/// - [2, 3, 4] is a 3D tensor +/// +/// This class provides utilities to work with these shapes: +/// - Check if two shapes are compatible for operations +/// - Compute the result shape when broadcasting +/// - Validate shapes +/// - Compare shapes +/// +/// These utilities are used by the JIT compiler to understand tensor dimensions +/// and generate optimized code. +/// +/// +public static class TensorShapeExtensions +{ + /// + /// Computes the total number of elements in a tensor with the given shape. + /// + /// The tensor shape. + /// The total number of elements, or -1 if any dimension is dynamic. + /// + /// For Beginners: This calculates how many total values a tensor holds. + /// + /// For example: + /// - [5] has 5 elements + /// - [3, 4] has 3 × 4 = 12 elements + /// - [2, 3, 4] has 2 × 3 × 4 = 24 elements + /// + /// If any dimension is -1 (meaning "dynamic" or "unknown"), returns -1. + /// + /// + public static int GetElementCount(this int[] shape) + { + if (shape.Length == 0) return 0; + + int count = 1; + foreach (var dim in shape) + { + if (dim < 0) return -1; // Dynamic dimension + count *= dim; + } + return count; + } + + /// + /// Gets the rank (number of dimensions) of a tensor shape. + /// + /// The tensor shape. + /// The number of dimensions. + /// + /// For Beginners: The rank is how many dimensions the tensor has. + /// + /// - [5] has rank 1 (a vector) + /// - [3, 4] has rank 2 (a matrix) + /// - [2, 3, 4] has rank 3 (a 3D tensor) + /// - [] has rank 0 (a scalar - single number) + /// + /// + public static int GetRank(this int[] shape) => shape.Length; + + /// + /// Checks if this shape is compatible with another shape for broadcasting. + /// + /// The first shape. + /// The second shape. + /// True if the shapes are compatible for broadcasting. + /// + /// + /// Broadcasting allows operations between tensors of different shapes by automatically + /// expanding dimensions. Two shapes are compatible if: + /// - They have the same rank and all dimensions match, OR + /// - One dimension is 1 (can be broadcast), OR + /// - One tensor has fewer dimensions (will be expanded) + /// + /// For Beginners: Broadcasting lets you do operations on tensors of different sizes. 
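+    /// In code (a sketch using these extension methods, including BroadcastWith below):
+    /// <code>
+    /// new[] { 3, 4 }.IsCompatibleWith(new[] { 1, 4 });   // true
+    /// new[] { 3, 4 }.BroadcastWith(new[] { 4 });         // [3, 4]
+    /// new[] { 3, 4 }.IsCompatibleWith(new[] { 3, 5 });   // false
+    /// </code>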
+ /// + /// For example: + /// - [3, 4] and [3, 4] are compatible (same shape) + /// - [3, 4] and [1, 4] are compatible (first dimension broadcasts) + /// - [3, 4] and [4] are compatible (vector broadcasts across all rows) + /// - [3, 4] and [3, 5] are NOT compatible (incompatible dimensions) + /// + /// This is very useful in neural networks where you often add a bias vector to every + /// row of a matrix - broadcasting handles this automatically. + /// + /// + public static bool IsCompatibleWith(this int[] shape1, int[] shape2) + { + if (shape1 == null || shape2 == null) return false; + + // Scalars are compatible with everything + if (shape1.Length == 0 || shape2.Length == 0) return true; + + // Check from right to left (trailing dimensions) + int maxRank = Math.Max(shape1.Length, shape2.Length); + for (int i = 1; i <= maxRank; i++) + { + int dim1 = i <= shape1.Length ? shape1[shape1.Length - i] : 1; + int dim2 = i <= shape2.Length ? shape2[shape2.Length - i] : 1; + + // Dimensions must be equal, one must be 1 (broadcast), or -1 (dynamic) + if (dim1 != dim2 && dim1 != 1 && dim2 != 1 && dim1 != -1 && dim2 != -1) + { + return false; + } + } + + return true; + } + + /// + /// Computes the broadcast shape resulting from combining two shapes. + /// + /// The first shape. + /// The second shape. + /// The broadcast result shape. + /// Thrown if shapes are not compatible. + /// + /// + /// The broadcast shape is computed by taking the maximum dimension at each position + /// when comparing from right to left. + /// + /// For Beginners: This calculates what shape results when broadcasting two tensors. + /// + /// Examples: + /// - [3, 4] + [3, 4] → [3, 4] (same shape) + /// - [3, 4] + [1, 4] → [3, 4] (first dimension expands from 1 to 3) + /// - [3, 4] + [4] → [3, 4] (vector broadcasts to match all rows) + /// - [5, 3, 4] + [4] → [5, 3, 4] (vector broadcasts across all 5×3 positions) + /// + /// The result tells us what shape the output will have after the operation. + /// + /// + public static int[] BroadcastWith(this int[] shape1, int[] shape2) + { + if (!shape1.IsCompatibleWith(shape2)) + { + throw new InvalidOperationException( + $"Shapes [{string.Join(", ", shape1)}] and [{string.Join(", ", shape2)}] " + + $"are not compatible for broadcasting"); + } + + int maxRank = Math.Max(shape1.Length, shape2.Length); + int[] resultShape = new int[maxRank]; + + for (int i = 1; i <= maxRank; i++) + { + int dim1 = i <= shape1.Length ? shape1[shape1.Length - i] : 1; + int dim2 = i <= shape2.Length ? shape2[shape2.Length - i] : 1; + + // Take maximum (handle dynamic dimensions) + if (dim1 == -1 || dim2 == -1) + { + resultShape[maxRank - i] = -1; // Dynamic + } + else + { + resultShape[maxRank - i] = Math.Max(dim1, dim2); + } + } + + return resultShape; + } + + /// + /// Checks if two shapes are exactly equal. + /// + /// The first shape. + /// The second shape. + /// True if shapes are equal. + /// + /// For Beginners: This checks if two shapes are identical. + /// + /// Examples: + /// - [3, 4] equals [3, 4] → true + /// - [3, 4] equals [4, 3] → false (different order!) + /// - [3, 4] equals [1, 4] → false (different dimensions) + /// + /// + public static bool ShapesEqual(int[]? shape1, int[]? 
shape2)
+    {
+        if (ReferenceEquals(shape1, shape2)) return true;
+        if (shape1 == null || shape2 == null) return false;
+        if (shape1.Length != shape2.Length) return false;
+
+        for (int i = 0; i < shape1.Length; i++)
+        {
+            if (shape1[i] != shape2[i])
+                return false;
+        }
+
+        return true;
+    }
+
+    /// 
+    /// Creates a string representation of a shape.
+    /// 
+    /// The shape to represent.
+    /// A string representation.
+    /// 
+    /// For Beginners: This converts a shape to a readable string for debugging.
+    /// 
+    /// Examples:
+    /// - [] → "scalar"
+    /// - [5] → "[5]"
+    /// - [3, 4] → "[3, 4]"
+    /// - [2, -1, 4] → "[2, ?, 4]" (? means dynamic)
+    /// 
+    /// 
+    public static string ShapeToString(this int[] shape)
+    {
+        if (shape.Length == 0) return "scalar";
+        return $"[{string.Join(", ", shape.Select(d => d >= 0 ? d.ToString() : "?"))}]";
+    }
+
+    /// 
+    /// Computes a hash code for a tensor shape.
+    /// 
+    /// The shape to hash.
+    /// A hash code.
+    /// 
+    /// 
+    /// This hash code can be used to cache compiled graphs based on shape.
+    /// Shapes with the same dimensions will have the same hash.
+    /// 
+    /// For Beginners: This creates a unique number that represents the shape.
+    /// 
+    /// It's like a fingerprint for the shape - two identical shapes will have
+    /// the same hash code. This is used to quickly check if we've already compiled
+    /// code for a tensor of this shape, so we can reuse it instead of recompiling.
+    /// 
+    /// 
+    public static int GetShapeHashCode(this int[] shape)
+    {
+        int hash = 17;
+        foreach (var dim in shape)
+        {
+            hash = hash * 31 + dim.GetHashCode();
+        }
+        return hash;
+    }
+
+    /// 
+    /// Extracts the shape from a Tensor.
+    /// 
+    /// The numeric type of the tensor.
+    /// The tensor.
+    /// The shape as an int array.
+    /// 
+    /// For Beginners: This gets the shape from an existing Tensor object.
+    /// 
+    /// Since Tensor already has a Shape property, this just returns it.
+    /// It's provided for consistency with the IR infrastructure.
+    /// 
+    /// 
+    public static int[] GetShape<T>(this Tensor<T> tensor)
+    {
+        return tensor.Shape;
+    }
+
+    /// 
+    /// Validates that a shape is well-formed.
+    /// 
+    /// The shape to validate.
+    /// True if valid.
+    /// 
+    /// 
+    /// A shape is valid if all dimensions are either positive or -1 (dynamic).
+    /// Zero dimensions are not allowed.
+    /// 
+    /// For Beginners: This checks that a shape makes sense.
+    /// 
+    /// Valid shapes:
+    /// - [] (scalar)
+    /// - [5] (vector with 5 elements)
+    /// - [3, 4] (3×4 matrix)
+    /// - [-1, 4] (dynamic first dimension, 4 columns)
+    /// 
+    /// Invalid shapes:
+    /// - [0, 4] (can't have zero dimension)
+    /// - [3, -2] (only -1 is allowed for dynamic)
+    /// 
+    /// 
+    public static bool IsValidShape(this int[] shape)
+    {
+        if (shape == null) return false;
+
+        foreach (var dim in shape)
+        {
+            // Dimensions must be positive or -1 (dynamic)
+            if (dim <= 0 && dim != -1)
+                return false;
+        }
+
+        return true;
+    }
+}