diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 43ef1965e..a0423310c 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -18,6 +18,7 @@ java_library( ":eval_const", ":eval_create_list", ":eval_create_map", + ":eval_create_struct", ":planned_program", "//:auto_value", "//common:cel_ast", @@ -26,6 +27,7 @@ java_library( "//common/ast", "//common/types", "//common/types:type_providers", + "//common/values:cel_value_provider", "//runtime:evaluation_exception", "//runtime:evaluation_exception_builder", "//runtime:interpretable", @@ -93,6 +95,22 @@ java_library( ], ) +java_library( + name = "eval_create_struct", + srcs = ["EvalCreateStruct.java"], + deps = [ + "//common/types", + "//common/values", + "//common/values:cel_value_provider", + "//runtime:evaluation_exception", + "//runtime:evaluation_listener", + "//runtime:function_resolver", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + ], +) + java_library( name = "eval_create_list", srcs = ["EvalCreateList.java"], diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java new file mode 100644 index 000000000..5c1f1d77b --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java @@ -0,0 +1,104 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.types.StructType; +import dev.cel.common.values.CelValueProvider; +import dev.cel.common.values.StructValue; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; +import dev.cel.runtime.CelFunctionResolver; +import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.Interpretable; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +@Immutable +final class EvalCreateStruct implements Interpretable { + + private final CelValueProvider valueProvider; + private final StructType structType; + + // Array contents are not mutated + @SuppressWarnings("Immutable") + private final String[] keys; + + // Array contents are not mutated + @SuppressWarnings("Immutable") + private final Interpretable[] values; + + @Override + public Object eval(GlobalResolver resolver) throws CelEvaluationException { + Map fieldValues = new HashMap<>(); + for (int i = 0; i < keys.length; i++) { + Object value = values[i].eval(resolver); + fieldValues.put(keys[i], value); + } + + // Either a primitive (wrappers) or a struct is produced + Object value = + valueProvider + .newValue(structType.name(), Collections.unmodifiableMap(fieldValues)) + .orElseThrow(() -> new IllegalArgumentException("Type name not found: " + structType)); + + if (value instanceof StructValue) { + return ((StructValue) value).value(); + } + + return value; + } + + @Override + public Object eval(GlobalResolver resolver, CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval( + GlobalResolver resolver, + CelFunctionResolver lateBoundFunctionResolver, + CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + static EvalCreateStruct create( + CelValueProvider valueProvider, + StructType structType, + String[] keys, + Interpretable[] values) { + return new EvalCreateStruct(valueProvider, structType, keys, values); + } + + private EvalCreateStruct( + CelValueProvider valueProvider, + StructType structType, + String[] keys, + Interpretable[] values) { + this.valueProvider = valueProvider; + this.structType = structType; + this.keys = keys; + this.values = values; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index 6f1f122d4..ec71155d6 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -25,11 +25,15 @@ import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelList; import dev.cel.common.ast.CelExpr.CelMap; +import dev.cel.common.ast.CelExpr.CelStruct; +import dev.cel.common.ast.CelExpr.CelStruct.Entry; import dev.cel.common.ast.CelReference; import dev.cel.common.types.CelKind; import dev.cel.common.types.CelType; import dev.cel.common.types.CelTypeProvider; +import dev.cel.common.types.StructType; import dev.cel.common.types.TypeType; +import dev.cel.common.values.CelValueProvider; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelEvaluationExceptionBuilder; import dev.cel.runtime.Interpretable; @@ -45,6 +49,7 @@ public final class ProgramPlanner { private final CelTypeProvider typeProvider; + private final CelValueProvider valueProvider; private final AttributeFactory attributeFactory; /** @@ -70,6 +75,8 @@ private Interpretable plan(CelExpr celExpr, PlannerContext ctx) { return planIdent(celExpr, ctx); case LIST: return planCreateList(celExpr, ctx); + case STRUCT: + return planCreateStruct(celExpr, ctx); case MAP: return planCreateMap(celExpr, ctx); case NOT_SET: @@ -131,6 +138,32 @@ private Interpretable planCheckedIdent( return EvalAttribute.create(attributeFactory.newAbsoluteAttribute(identRef.name())); } + private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { + CelStruct struct = celExpr.struct(); + CelType structType = + typeProvider + .findType(struct.messageName()) + .orElseThrow( + () -> new IllegalArgumentException("Undefined type name: " + struct.messageName())); + if (!structType.kind().equals(CelKind.STRUCT)) { + throw new IllegalArgumentException( + String.format( + "Expected struct type for %s, got %s", structType.name(), structType.kind())); + } + + ImmutableList entries = struct.entries(); + String[] keys = new String[entries.size()]; + Interpretable[] values = new Interpretable[entries.size()]; + + for (int i = 0; i < entries.size(); i++) { + Entry entry = entries.get(i); + keys[i] = entry.fieldKey(); + values[i] = plan(entry.value(), ctx); + } + + return EvalCreateStruct.create(valueProvider, (StructType) structType, keys, values); + } + private Interpretable planCreateList(CelExpr celExpr, PlannerContext ctx) { CelList list = celExpr.list(); @@ -172,12 +205,14 @@ private static PlannerContext create(CelAbstractSyntaxTree ast) { } } - public static ProgramPlanner newPlanner(CelTypeProvider typeProvider) { - return new ProgramPlanner(typeProvider); + public static ProgramPlanner newPlanner( + CelTypeProvider typeProvider, CelValueProvider valueProvider) { + return new ProgramPlanner(typeProvider, valueProvider); } - private ProgramPlanner(CelTypeProvider typeProvider) { + private ProgramPlanner(CelTypeProvider typeProvider, CelValueProvider valueProvider) { this.typeProvider = typeProvider; + this.valueProvider = valueProvider; // TODO: Container support this.attributeFactory = AttributeFactory.newAttributeFactory(CelContainer.newBuilder().build(), typeProvider); diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel index 46ef0f279..0b36e20ac 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel @@ -15,14 +15,21 @@ java_library( deps = [ "//:java_truth", "//common:cel_ast", + "//common:cel_descriptor_util", "//common:cel_source", + "//common:options", "//common/ast", + "//common/internal:cel_descriptor_pools", + "//common/internal:default_message_factory", + "//common/internal:dynamic_proto", "//common/types", "//common/types:default_type_provider", "//common/types:message_type_provider", "//common/types:type_providers", "//common/values", "//common/values:cel_byte_string", + "//common/values:cel_value_provider", + "//common/values:proto_message_value_provider", "//compiler", "//compiler:compiler_builder", "//extensions", diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 8d119c3c3..32651d73d 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -25,8 +25,14 @@ import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelOptions; import dev.cel.common.CelSource; import dev.cel.common.ast.CelExpr; +import dev.cel.common.internal.CelDescriptorPool; +import dev.cel.common.internal.DefaultDescriptorPool; +import dev.cel.common.internal.DefaultMessageFactory; +import dev.cel.common.internal.DynamicProto; import dev.cel.common.types.CelType; import dev.cel.common.types.CelTypeProvider; import dev.cel.common.types.CelTypeProvider.CombinedCelTypeProvider; @@ -38,7 +44,9 @@ import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeType; import dev.cel.common.values.CelByteString; +import dev.cel.common.values.CelValueProvider; import dev.cel.common.values.NullValue; +import dev.cel.common.values.ProtoMessageValueProvider; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.GlobalEnum; @@ -56,8 +64,17 @@ public final class ProgramPlannerTest { new CombinedCelTypeProvider( DefaultTypeProvider.getInstance(), new ProtoMessageTypeProvider(ImmutableSet.of(TestAllTypes.getDescriptor()))); - - private static final ProgramPlanner PLANNER = ProgramPlanner.newPlanner(TYPE_PROVIDER); + private static final CelDescriptorPool DESCRIPTOR_POOL = + DefaultDescriptorPool.create( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + TestAllTypes.getDescriptor().getFile())); + private static final DynamicProto DYNAMIC_PROTO = + DynamicProto.create(DefaultMessageFactory.create(DESCRIPTOR_POOL)); + private static final CelValueProvider VALUE_PROVIDER = + ProtoMessageValueProvider.newInstance(CelOptions.DEFAULT, DYNAMIC_PROTO); + + private static final ProgramPlanner PLANNER = + ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER); private static final CelCompiler CEL_COMPILER = CelCompilerFactory.standardCelCompilerBuilder() .addVar("int_var", SimpleType.INT) @@ -113,6 +130,24 @@ public void plan_ident_variable() throws Exception { assertThat(result).isEqualTo(1); } + @Test + public void planIdent_typeLiteral(@TestParameter TypeLiteralTestCase testCase) throws Exception { + if (isParseOnly) { + if (testCase.equals(TypeLiteralTestCase.DURATION) + || testCase.equals(TypeLiteralTestCase.TIMESTAMP) + || testCase.equals(TypeLiteralTestCase.PROTO_MESSAGE_TYPE)) { + // TODO Skip for now, requires attribute qualification + return; + } + } + CelAbstractSyntaxTree ast = compile(testCase.expression); + Program program = PLANNER.plan(ast); + + TypeType result = (TypeType) program.eval(); + + assertThat(result).isEqualTo(testCase.type); + } + @Test @SuppressWarnings("unchecked") // test only public void plan_createList() throws Exception { @@ -136,21 +171,39 @@ public void plan_createMap() throws Exception { } @Test - public void planIdent_typeLiteral(@TestParameter TypeLiteralTestCase testCase) throws Exception { - if (isParseOnly) { - if (testCase.equals(TypeLiteralTestCase.DURATION) - || testCase.equals(TypeLiteralTestCase.TIMESTAMP) - || testCase.equals(TypeLiteralTestCase.PROTO_MESSAGE_TYPE)) { - // TODO Skip for now, requires attribute qualification - return; - } - } - CelAbstractSyntaxTree ast = compile(testCase.expression); + public void plan_createStruct() throws Exception { + CelAbstractSyntaxTree ast = compile("cel.expr.conformance.proto3.TestAllTypes{}"); Program program = PLANNER.plan(ast); - TypeType result = (TypeType) program.eval(); + TestAllTypes result = (TestAllTypes) program.eval(); - assertThat(result).isEqualTo(testCase.type); + assertThat(result).isEqualTo(TestAllTypes.getDefaultInstance()); + } + + @Test + public void plan_createStruct_wrapper() throws Exception { + CelAbstractSyntaxTree ast = compile("google.protobuf.StringValue { value: 'foo' }"); + Program program = PLANNER.plan(ast); + + String result = (String) program.eval(); + + assertThat(result).isEqualTo("foo"); + } + + @Test + public void planCreateStruct_withFields() throws Exception { + CelAbstractSyntaxTree ast = + compile( + "cel.expr.conformance.proto3.TestAllTypes{" + + "single_string: 'foo'," + + "single_bool: true" + + "}"); + Program program = PLANNER.plan(ast); + + TestAllTypes result = (TestAllTypes) program.eval(); + + assertThat(result) + .isEqualTo(TestAllTypes.newBuilder().setSingleString("foo").setSingleBool(true).build()); } private CelAbstractSyntaxTree compile(String expression) throws Exception {