From 0295c5288cc60ed1e0055bcef8f73e8b9999241d Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 3 Mar 2026 14:38:20 -0800 Subject: [PATCH 01/88] Update error message. PiperOrigin-RevId: 878130486 --- eval/eval/regex_match_step_test.cc | 2 +- internal/re2_options.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/eval/regex_match_step_test.cc b/eval/eval/regex_match_step_test.cc index 8d54a0188..53b955b25 100644 --- a/eval/eval/regex_match_step_test.cc +++ b/eval/eval/regex_match_step_test.cc @@ -94,7 +94,7 @@ TEST(RegexMatchStep, PrecompiledInvalidProgramTooLarge) { ASSERT_OK(RegisterBuiltinFunctions(expr_builder->GetRegistry(), options)); EXPECT_THAT(expr_builder->CreateExpression(&checked_expr), StatusIs(absl::StatusCode::kInvalidArgument, - Eq("regular expressions exceeds max allowed size"))); + Eq("regular expression exceeds max allowed size"))); } } // namespace diff --git a/internal/re2_options.h b/internal/re2_options.h index 9c20ceb63..25a30f6bd 100644 --- a/internal/re2_options.h +++ b/internal/re2_options.h @@ -45,13 +45,13 @@ inline absl::Status CheckRE2(const RE2& re, int max_program_size) { if (max_program_size > 0 && program_size > 0 && program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } int reverse_program_size = re.ReverseProgramSize(); if (max_program_size > 0 && reverse_program_size > 0 && reverse_program_size > max_program_size) { return absl::InvalidArgumentError( - "regular expressions exceeds max allowed size"); + "regular expression exceeds max allowed size"); } return absl::OkStatus(); } From 0fc3715279c337e77aabebdbb954c895054f3d09 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 6 Mar 2026 22:13:29 -0800 Subject: [PATCH 02/88] No public description PiperOrigin-RevId: 879974382 --- eval/public/value_export_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index edb6e83e0..bca8a8d65 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -67,8 +67,8 @@ absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value, break; } case CelValue::Type::kBytes: { - absl::Base64Escape(in_value.BytesOrDie().value(), - out_value->mutable_string_value()); + *out_value->mutable_string_value() = + absl::Base64Escape(in_value.BytesOrDie().value()); break; } case CelValue::Type::kDuration: { From dd1bceff1ec9eda8d2c6a7ebb2651d3a14e513d4 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 10 Mar 2026 15:06:43 -0700 Subject: [PATCH 03/88] Rewrite IdentStep as Constant when variable value is provided to TypeChecker PiperOrigin-RevId: 881641178 --- eval/compiler/qualified_reference_resolver.cc | 11 ++- .../qualified_reference_resolver_test.cc | 91 +++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 09950bfe8..67f86ebb6 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -135,8 +135,17 @@ class ReferenceResolver : public cel::AstRewriterBase { expr.mutable_const_expr().set_int64_value( reference->value().int64_value()); return true; + } else if (expr.has_ident_expr()) { + // "google.protobuf.NullValue.NULL_VALUE" is a special case: sometimes + // it is interpreted as null value and sometimes as an enum constant. + if (reference->value().has_null_value() && + expr.ident_expr().name() == + "google.protobuf.NullValue.NULL_VALUE") { + return false; + } + expr.set_const_expr(reference->value()); + return true; } else { - // No update if the constant reference isn't an int (an enum value). 
return false; } } diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 0d710a465..3fa7fca21 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -45,6 +45,7 @@ namespace google::api::expr::runtime { namespace { +using ::absl_testing::IsOk; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::Ast; @@ -343,6 +344,60 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { })pb")); } +// foo && bar +constexpr char kConstReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + ident_expr { + name: "foo" + } + } + args { + id: 5 + ident_expr { + name: "bar" + } + } + } +)"; + +TEST(ResolveReferences, ConstReferenceFolded) { + std::unique_ptr expr_ast = ParseTestProto(kConstReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name("foo"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar"); + expr_ast->mutable_reference_map()[5].mutable_value().set_bool_value(false); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(true)); + + EXPECT_THAT(ExprToProtoOrDie(expr_ast->root_expr()), EqualsProto(R"pb( + id: 1 + call_expr { + function: "_&&_" + args { + id: 2 + const_expr { bool_value: true } + } + args { + id: 5 + const_expr { bool_value: false } + } + })pb")); +} + TEST(ResolveReferences, ConstReferenceSkipped) { std::unique_ptr expr_ast = ParseTestProto(kExpr); SourceInfo source_info; @@ -388,6 +443,42 @@ 
TEST(ResolveReferences, ConstReferenceSkipped) { })pb")); } +constexpr char kNullValueReferenceExpr[] = R"( + id: 1 + call_expr { + function: "_+_" + args { + id: 2 + ident_expr { + name: "google.protobuf.NullValue.NULL_VALUE" + } + } + args { + id: 5 + const_expr { int64_value: 1 } + } + } +)"; + +TEST(ResolveReferences, NullValueReferenceSkipped) { + std::unique_ptr expr_ast = ParseTestProto(kNullValueReferenceExpr); + SourceInfo source_info; + + CelFunctionRegistry func_registry; + ASSERT_THAT(RegisterBuiltinFunctions(&func_registry), IsOk()); + cel::TypeRegistry type_registry; + Resolver registry("", func_registry.InternalGetRegistry(), type_registry, + type_registry.GetComposedTypeProvider()); + expr_ast->mutable_reference_map()[2].set_name( + "google.protobuf.NullValue.NULL_VALUE"); + expr_ast->mutable_reference_map()[2].mutable_value().set_null_value(nullptr); + IssueCollector issues(RuntimeIssue::Severity::kError); + + auto result = ResolveReferences(registry, issues, *expr_ast); + + ASSERT_THAT(result, IsOkAndHolds(/*was_rewritten=*/false)); +} + constexpr char kExtensionAndExpr[] = R"( id: 1 call_expr { From 8be02b6f370d601a72e9c8bd64516168fc558da8 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 11 Mar 2026 12:15:14 -0700 Subject: [PATCH 04/88] Update checker to treat gp.NullValue as an int constant. Makes the C++ checker behave more consistent with cel-go and cel-java. Direct references to google.protobuf.NullValue should behave as an enum/int, but google.protobuf.Value{} with null_value alternative active behaves as CEL null. 
PiperOrigin-RevId: 882134119 --- checker/internal/type_checker_impl.cc | 18 +----------------- checker/internal/type_checker_impl_test.cc | 5 ----- checker/optional_test.cc | 11 ++++++----- checker/standard_library.cc | 6 ++---- checker/type_checker_builder_factory_test.cc | 2 +- common/type.cc | 4 +++- common/type_test.cc | 4 ++-- 7 files changed, 15 insertions(+), 35 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 14dce1647..1e9995b19 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -103,21 +103,6 @@ SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { return SourceLocation{line_idx + 1, rel_position}; } -// Special case for protobuf null fields. -bool IsPbNullFieldAssignable(const Type& value, const Type& field) { - if (field.IsNull()) { - return value.IsInt() || value.IsNull(); - } - - if (field.IsOptional() && value.IsOptional() && - field.AsOptional()->GetParameter().IsNull()) { - auto value_param = value.AsOptional()->GetParameter(); - return value_param.IsInt() || value_param.IsNull(); - } - - return false; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. 
// @@ -421,8 +406,7 @@ class ResolveVisitor : public AstVisitorBase { if (field.optional()) { field_type = OptionalType(arena_, field_type); } - if (!inference_context_->IsAssignable(value_type, field_type) && - !IsPbNullFieldAssignable(value_type, field_type)) { + if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, field.id()), absl::StrCat( diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index c36051376..6eccc3701 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -1966,11 +1966,6 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, - CheckedExprTestCase{ - .expr = "TestAllTypes{null_value: null}", - .expected_result_type = AstType( - MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), - }, // Legacy nullability behaviors. CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: null}", diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 28ae9a889..87c14f0cd 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -267,15 +267,16 @@ INSTANTIATE_TEST_SUITE_P( IsOptionalType(TypeSpec(PrimitiveType::kString))}, TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", IsOptionalType(TypeSpec(PrimitiveType::kString))}, - // Legacy nullability behaviors. TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(0)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}", - Eq(TypeSpec(MessageTypeSpec( - "cel.expr.conformance.proto3.TestAllTypes")))}, - TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " + // Legacy nullability behaviors. 
+ TestCase{ + "cel.expr.conformance.proto3.TestAllTypes{?single_value: null}", + Eq(TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")))}, + TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_value: " "optional.of(null)}", Eq(TypeSpec(MessageTypeSpec( "cel.expr.conformance.proto3.TestAllTypes")))}, diff --git a/checker/standard_library.cc b/checker/standard_library.cc index 4cd9e9831..744a171ef 100644 --- a/checker/standard_library.cc +++ b/checker/standard_library.cc @@ -14,6 +14,7 @@ #include "checker/standard_library.h" +#include #include #include "absl/base/no_destructor.h" @@ -833,11 +834,8 @@ absl::Status AddTypeConstantVariables(TypeCheckerBuilder& builder) { absl::Status AddEnumConstants(TypeCheckerBuilder& builder) { VariableDecl pb_null; pb_null.set_name("google.protobuf.NullValue.NULL_VALUE"); - // TODO(uncreated-issue/74): This is interpreted as an enum (int) or null in - // different cases. We should add some additional spec tests to cover this and - // update the behavior to be consistent. pb_null.set_type(IntType()); - pb_null.set_value(Constant(nullptr)); + pb_null.set_value(Constant(int64_t{0})); CEL_RETURN_IF_ERROR(builder.AddVariable(std::move(pb_null))); return absl::OkStatus(); } diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index d5cf47fee..a15d2e173 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -464,7 +464,7 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationValue) { // Note: one of fields are all added with safe traversal, so // we lose the union discriminator information. 
R"cel( - null_value == null && + null_value == 0 && number_value == 0.0 && string_value == '' && list_value == [] && diff --git a/common/type.cc b/common/type.cc index 2b81e39f8..ce8c7a89a 100644 --- a/common/type.cc +++ b/common/type.cc @@ -75,7 +75,9 @@ Type Type::Message(const Descriptor* absl_nonnull descriptor) { Type Type::Enum(const google::protobuf::EnumDescriptor* absl_nonnull descriptor) { if (descriptor->full_name() == "google.protobuf.NullValue") { - return NullType(); + // Special case NullValue to prevent the emebedder providing a different + // descriptor for it and it leaking. + return IntType(); } return EnumType(descriptor); } diff --git a/common/type_test.cc b/common/type_test.cc index 119234fdc..2cebf27ba 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -45,7 +45,7 @@ TEST(Type, Enum) { EXPECT_EQ(Type::Enum( ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( "google.protobuf.NullValue"))), - NullType()); + IntType()); } TEST(Type, Field) { @@ -58,7 +58,7 @@ TEST(Type, Field) { BoolType()); EXPECT_EQ( Type::Field(ABSL_DIE_IF_NULL(descriptor->FindFieldByName("null_value"))), - NullType()); + IntType()); EXPECT_EQ(Type::Field( ABSL_DIE_IF_NULL(descriptor->FindFieldByName("single_int32"))), IntType()); From 5cc391e2ea5e3d2d563bf18b33fd411fb1f46d84 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 13 Mar 2026 16:31:23 -0700 Subject: [PATCH 05/88] internal change PiperOrigin-RevId: 883393454 --- runtime/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/runtime/BUILD b/runtime/BUILD index b58880146..776a8223d 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -344,14 +344,11 @@ cc_test( deps = [ ":activation", ":constant_folding", - ":function", - ":register_function_helper", ":runtime_builder", ":runtime_options", ":standard_runtime_builder_factory", "//base:function_adapter", "//common:function_descriptor", - "//common:kind", "//common:value", "//extensions/protobuf:runtime_adapter", "//internal:testing", From 
f72d56a653f93d474479108958c9d8a64a0d97fb Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 20 Mar 2026 18:17:21 -0700 Subject: [PATCH 06/88] Fix C++-17 compatibility issue for double(string) impl. StartIt, EndIt ctor for std::string_view is a C++ 20 feature. PiperOrigin-RevId: 887065454 --- runtime/standard/type_conversion_functions.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/standard/type_conversion_functions.cc b/runtime/standard/type_conversion_functions.cc index 50b6e28ea..76e95751b 100644 --- a/runtime/standard/type_conversion_functions.cc +++ b/runtime/standard/type_conversion_functions.cc @@ -69,7 +69,7 @@ Value FormatDouble(double v, const Function::InvokeContext& context) { return cel::ErrorValue(absl::InvalidArgumentError(absl::StrCat( "double format error: ", std::make_error_code(result.ec).message()))); } - absl::string_view out(buf, result.ptr); + absl::string_view out(buf, result.ptr - buf); return StringValue::From(out, arena); #endif } From 298ac3d30f42a67996ed36b2f3ce01fc6757faa9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 23 Mar 2026 13:30:26 -0700 Subject: [PATCH 07/88] Simplify recursive planning limit checks. Instead of trying to wrap a recursive subprogram in a stack machine step, the planner fails if the limit is exceeded. This simplifies program planning and avoids high overhead in deep but unbalanced ASTs. 
PiperOrigin-RevId: 888254034 --- eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 271 ++++++++---------- eval/public/cel_options.h | 16 +- .../expression_builder_benchmark_test.cc | 95 ++++-- runtime/runtime_options.h | 16 +- .../standard_runtime_builder_factory_test.cc | 92 +++++- 6 files changed, 297 insertions(+), 194 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index e82b0ce13..62b208772 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,6 +33,7 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", + "//common:navigable_ast", "//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a0fd427bd..91822092c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -109,6 +110,13 @@ constexpr absl::string_view kBlock = "cel.@block"; // Forward declare to resolve circular dependency for short_circuiting visitors. class FlatExprVisitor; +// Error code for failed recursive program building. Generally indicates an +// optimization doesn't support recursive programs. +absl::Status FailedRecursivePlanning() { + return absl::InternalError( + "failed to build recursive program. check for unsupported optimizations"); +} + // Helper for bookkeeping variables mapped to indexes. 
class IndexManager { public: @@ -577,6 +585,12 @@ class FlatExprVisitor : public cel::AstVisitor { } } + void SetMaxRecursionDepth(int max_recursion_depth) { + max_recursion_depth_ = max_recursion_depth; + } + + bool PlanRecursiveProgram() const { return max_recursion_depth_ > 0; } + void PreVisitExpr(const cel::Expr& expr) override { ValidateOrError(!absl::holds_alternative(expr.kind()), "Invalid empty expression"); @@ -947,8 +961,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { SetProgressStatusError(absl::InternalError( @@ -1064,21 +1077,13 @@ class FlatExprVisitor : public cel::AstVisitor { } } + // Returns the maximum recursion depth of the current program if it is + // eligible for recursion, or nullopt if it is not. absl::optional RecursionEligible() { - if (program_builder_.current() == nullptr) { + if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } - absl::optional depth = - program_builder_.current()->RecursiveDependencyDepth(); - if (!depth.has_value()) { - // one or more of the dependencies isn't eligible. 
- return depth; - } - if (options_.max_recursion_depth < 0 || - *depth < options_.max_recursion_depth) { - return depth; - } - return absl::nullopt; + return program_builder_.current()->RecursiveDependencyDepth(); } std::vector> @@ -1089,10 +1094,7 @@ class FlatExprVisitor : public cel::AstVisitor { return program_builder_.current()->ExtractRecursiveDependencies(); } - void MaybeMakeTernaryRecursive(const cel::Expr* expr) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); @@ -1107,26 +1109,16 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, condition_plan->recursive_program().depth); - - if (left_plan == nullptr || !left_plan->IsRecursive()) { + if (condition_plan == nullptr || !condition_plan->IsRecursive() || + left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, condition_plan->recursive_program().depth, + left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep( CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, @@ -1136,10 +1128,7 @@ class FlatExprVisitor : 
public cel::AstVisitor { max_depth + 1); } - void MaybeMakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { if (expr->call_expr().args().size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); @@ -1151,21 +1140,14 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if (left_plan == nullptr || !left_plan->IsRecursive()) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { - return; - } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); if (is_or) { SetRecursiveStep( @@ -1182,11 +1164,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - void MaybeMakeOptionalShortcircuitRecursive(const cel::Expr* expr, - bool is_or_value) { - if (options_.max_recursion_depth == 0) { - return; - } + void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -1199,21 +1177,13 @@ class FlatExprVisitor : public cel::AstVisitor { auto* left_plan = program_builder_.GetSubexpression(left_expr); auto* right_plan = program_builder_.GetSubexpression(right_expr); - int max_depth = 0; - if 
(left_plan == nullptr || !left_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, left_plan->recursive_program().depth); - - if (right_plan == nullptr || !right_plan->IsRecursive()) { - return; - } - max_depth = std::max(max_depth, right_plan->recursive_program().depth); - - if (options_.max_recursion_depth >= 0 && - max_depth >= options_.max_recursion_depth) { + if (left_plan == nullptr || !left_plan->IsRecursive() || + right_plan == nullptr || !right_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } + int max_depth = std::max({0, left_plan->recursive_program().depth, + right_plan->recursive_program().depth}); SetRecursiveStep(CreateDirectOptionalOrStep( expr->id(), left_plan->ExtractRecursiveProgram().step, @@ -1225,7 +1195,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeBindRecursive(const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } @@ -1233,16 +1203,12 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } int result_depth = result_plan->recursive_program().depth; - if (options_.max_recursion_depth > 0 && - result_depth >= options_.max_recursion_depth) { - return; - } - auto program = result_plan->ExtractRecursiveProgram(); SetRecursiveStep( CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), @@ -1252,42 +1218,26 @@ class FlatExprVisitor : public cel::AstVisitor { void MaybeMakeComprehensionRecursive( const cel::Expr* expr, const cel::ComprehensionExpr* comprehension, size_t iter_slot, size_t iter2_slot, size_t accu_slot) { - if (options_.max_recursion_depth == 0) { + if (!PlanRecursiveProgram()) { return; } auto* accu_plan = 
program_builder_.GetSubexpression(&comprehension->accu_init()); - - if (accu_plan == nullptr || !accu_plan->IsRecursive()) { - return; - } - auto* range_plan = program_builder_.GetSubexpression(&comprehension->iter_range()); - - if (range_plan == nullptr || !range_plan->IsRecursive()) { - return; - } - auto* loop_plan = program_builder_.GetSubexpression(&comprehension->loop_step()); - - if (loop_plan == nullptr || !loop_plan->IsRecursive()) { - return; - } - auto* condition_plan = program_builder_.GetSubexpression(&comprehension->loop_condition()); - - if (condition_plan == nullptr || !condition_plan->IsRecursive()) { - return; - } - auto* result_plan = program_builder_.GetSubexpression(&comprehension->result()); - - if (result_plan == nullptr || !result_plan->IsRecursive()) { + if (accu_plan == nullptr || !accu_plan->IsRecursive() || + range_plan == nullptr || !range_plan->IsRecursive() || + loop_plan == nullptr || !loop_plan->IsRecursive() || + condition_plan == nullptr || !condition_plan->IsRecursive() || + result_plan == nullptr || !result_plan->IsRecursive()) { + SetProgressStatusError(FailedRecursivePlanning()); return; } @@ -1298,11 +1248,6 @@ class FlatExprVisitor : public cel::AstVisitor { max_depth = std::max(max_depth, condition_plan->recursive_program().depth); max_depth = std::max(max_depth, result_plan->recursive_program().depth); - if (options_.max_recursion_depth > 0 && - max_depth >= options_.max_recursion_depth) { - return; - } - auto step = CreateDirectComprehensionStep( iter_slot, iter2_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, @@ -1566,7 +1511,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableListStep(expr.id()), 1); return; } @@ -1579,8 +1524,7 @@ class FlatExprVisitor : public 
cel::AstVisitor { } } } - absl::optional depth = RecursionEligible(); - if (depth.has_value()) { + if (absl::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1614,8 +1558,7 @@ class FlatExprVisitor : public cel::AstVisitor { std::vector fields = std::move(status_or_resolved_fields.value().second); - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { SetProgressStatusError(absl::InternalError( @@ -1646,7 +1589,7 @@ class FlatExprVisitor : public cel::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_map_insert) { if (&(comprehension.comprehension->accu_init()) == &expr) { - if (options_.max_recursion_depth != 0) { + if (PlanRecursiveProgram()) { SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1); return; } @@ -1656,8 +1599,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { SetProgressStatusError(absl::InternalError( @@ -1696,8 +1638,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = program_builder_.current()->ExtractRecursiveDependencies(); SetRecursiveStep(CreateDirectLazyFunctionStep( expr->id(), *call_expr, std::move(args), @@ -1727,8 +1668,9 @@ class FlatExprVisitor : public cel::AstVisitor { return; } } - auto 
recursion_depth = RecursionEligible(); - if (recursion_depth.has_value()) { + + if (auto recursion_depth = RecursionEligible(); + recursion_depth.has_value()) { // Nonnull while active -- nullptr indicates logic error elsewhere in the // builder. ABSL_DCHECK(program_builder_.current() != nullptr); @@ -1777,6 +1719,11 @@ class FlatExprVisitor : public cel::AstVisitor { return; } program_builder_.current()->set_recursive_program(std::move(step), depth); + if (depth > max_recursion_depth_) { + SetProgressStatusError(absl::InvalidArgumentError( + absl::StrCat("Maximum recursion depth of ", + options_.max_recursion_depth, " exceeded"))); + } } void SetProgressStatusError(const absl::Status& status) { @@ -1980,17 +1927,17 @@ class FlatExprVisitor : public cel::AstVisitor { IssueCollector& issue_collector_; ProgramBuilder& program_builder_; - PlannerContext extension_context_; + PlannerContext& extension_context_; IndexManager index_manager_; bool enable_optional_types_; absl::optional block_; + int max_recursion_depth_ = 0; }; FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( const cel::Expr& expr, const cel::CallExpr& call_expr) { ABSL_DCHECK(call_expr.function() == cel::builtin::kIndex); - auto depth = RecursionEligible(); if (!ValidateOrError( (call_expr.args().size() == 2 && !call_expr.has_target()) || // TODO(uncreated-issue/79): A few clients use the index operator with a @@ -2000,7 +1947,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2027,9 +1974,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { 
auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2046,15 +1991,13 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( const cel::Expr& expr, const cel::CallExpr& call_expr) { - auto depth = RecursionEligible(); - if (!ValidateOrError(call_expr.args().size() == 1 && !call_expr.has_target(), "unexpected number of args for builtin " "not_strictly_false operator")) { return CallHandlerResult::kIntercepted; } - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { SetProgressStatusError( @@ -2155,9 +2098,8 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( "unexpected number of args for builtin equality operator")) { return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2182,8 +2124,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, return CallHandlerResult::kIntercepted; } - auto depth = RecursionEligible(); - if (depth.has_value()) { + if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { SetProgressStatusError(absl::InvalidArgumentError( @@ -2221,6 +2162,9 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && arg_num == 0 && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, @@ -2248,6 +2192,9 @@ void 
BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } if (short_circuiting_ && (cond_ == BinaryCond::kOptionalOr || cond_ == BinaryCond::kOptionalOrValue)) { // If first branch evaluation result is enough to determine output, @@ -2275,6 +2222,28 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { } void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/false); + break; + case BinaryCond::kOr: + visitor_->MakeShortcircuitRecursive(expr, /*is_or=*/true); + break; + case BinaryCond::kOptionalOr: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/false); + break; + case BinaryCond::kOptionalOrValue: + visitor_->MakeOptionalShortcircuit(expr, + /*is_or_value=*/true); + break; + default: + ABSL_UNREACHABLE(); + } + return; + } + switch (cond_) { case BinaryCond::kAnd: visitor_->AddStep(CreateAndStep(expr->id())); @@ -2298,26 +2267,6 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_step_.set_target(visitor_->GetCurrentIndex())); } - // Handle maybe replacing the subprogram with a recursive version. This needs - // to happen after the jump step is updated (though it may get overwritten). 
- switch (cond_) { - case BinaryCond::kAnd: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/false); - break; - case BinaryCond::kOr: - visitor_->MaybeMakeShortcircuitRecursive(expr, /*is_or=*/true); - break; - case BinaryCond::kOptionalOr: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/false); - break; - case BinaryCond::kOptionalOrValue: - visitor_->MaybeMakeOptionalShortcircuitRecursive(expr, - /*is_or_value=*/true); - break; - default: - ABSL_UNREACHABLE(); - } } void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2327,6 +2276,9 @@ void TernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } // Ternary operator "_?_:_" requires a special handing. // In contrary to regular function call, its execution affects the control // flow of the overall CEL expression. @@ -2380,6 +2332,10 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { } void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } // Determine and set jump offset in jump instruction. 
if (visitor_->ValidateOrError( error_jump_.exists(), @@ -2393,7 +2349,6 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { visitor_->SetProgressStatusError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } - visitor_->MaybeMakeTernaryRecursive(expr); } void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { @@ -2403,8 +2358,11 @@ void ExhaustiveTernaryCondVisitor::PreVisit(const cel::Expr* expr) { } void ExhaustiveTernaryCondVisitor::PostVisit(const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + visitor_->MakeTernaryRecursive(expr); + return; + } visitor_->AddStep(CreateTernaryStep(expr->id())); - visitor_->MaybeMakeTernaryRecursive(expr); } void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { @@ -2417,6 +2375,9 @@ void ComprehensionVisitor::PreVisit(const cel::Expr* expr) { absl::Status ComprehensionVisitor::PostVisitArgDefault( cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return absl::OkStatus(); + } switch (arg_num) { case cel::ITER_RANGE: { init_step_pos_ = visitor_->GetCurrentIndex(); @@ -2491,6 +2452,9 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( void ComprehensionVisitor::PostVisitArgTrivial(cel::ComprehensionArg arg_num, const cel::Expr* expr) { + if (visitor_->PlanRecursiveProgram()) { + return; + } switch (arg_num) { case cel::ITER_RANGE: { break; @@ -2590,6 +2554,13 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( issue_collector, program_builder, extension_context, enable_optional_types_); + if (options_.max_recursion_depth == -1 || options_.max_recursion_depth > 0) { + int depth_limit = options_.max_recursion_depth == -1 + ? 
std::numeric_limits::max() + : options_.max_recursion_depth; + visitor.SetMaxRecursionDepth(depth_limit); + } + cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; AstTraverse(ast->root_expr(), visitor, opts); diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 779839583..4d81eb8a7 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -171,17 +171,23 @@ struct InterpreterOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program). + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs.
diff --git a/eval/tests/expression_builder_benchmark_test.cc b/eval/tests/expression_builder_benchmark_test.cc index c26a7cd5c..410df8902 100644 --- a/eval/tests/expression_builder_benchmark_test.cc +++ b/eval/tests/expression_builder_benchmark_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
#include #include @@ -50,8 +48,24 @@ using google::api::expr::parser::Parse; enum BenchmarkParam : int { kDefault = 0, kFoldConstants = 1, + kRecursivePlanning = 2, + kRecursivePlanningWithConstantFolding = 3, }; +std::string LabelForParam(BenchmarkParam param) { + switch (param) { + case BenchmarkParam::kDefault: + return "default"; + case BenchmarkParam::kFoldConstants: + return "fold_constants"; + case BenchmarkParam::kRecursivePlanning: + return "recursive_planning"; + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + return "recursive_planning_with_constant_folding"; + } + return "unknown"; +} + void BM_RegisterBuiltins(benchmark::State& state) { for (auto _ : state) { auto builder = CreateCelExpressionBuilder(); @@ -64,21 +78,33 @@ BENCHMARK(BM_RegisterBuiltins); InterpreterOptions OptionsForParam(BenchmarkParam param, google::protobuf::Arena& arena) { InterpreterOptions options; - switch (param) { case BenchmarkParam::kFoldConstants: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: options.constant_arena = &arena; options.constant_folding = true; break; case BenchmarkParam::kDefault: + case BenchmarkParam::kRecursivePlanning: options.constant_folding = false; break; } + switch (param) { + case BenchmarkParam::kRecursivePlanning: + case BenchmarkParam::kRecursivePlanningWithConstantFolding: + options.max_recursion_depth = 48; + break; + case BenchmarkParam::kDefault: + case BenchmarkParam::kFoldConstants: + options.max_recursion_depth = 0; + break; + } return options; } void BM_SymbolicPolicy(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( !(request.ip in ["10.0.1.4", "10.0.1.5", "10.0.1.6"]) && @@ -105,7 +131,9 @@ void BM_SymbolicPolicy(benchmark::State& state) { BENCHMARK(BM_SymbolicPolicy) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + 
->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); absl::StatusOr> MakeBuilderForEnums( absl::string_view container, absl::string_view enum_type, @@ -209,6 +237,7 @@ BENCHMARK(BM_EnumResolution256Candidate)->ThreadRange(1, 32); void BM_NestedComprehension(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( [4, 5, 6].all(x, [1, 2, 3].all(y, x > y) && [7, 8, 9].all(z, x < z)) @@ -231,10 +260,13 @@ void BM_NestedComprehension(benchmark::State& state) { BENCHMARK(BM_NestedComprehension) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_Comparisons(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( v11 < v12 && v12 < v13 @@ -260,7 +292,9 @@ void BM_Comparisons(benchmark::State& state) { BENCHMARK(BM_Comparisons) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_ComparisonsConcurrent(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"( @@ -290,6 +324,8 @@ BENCHMARK(BM_ComparisonsConcurrent)->ThreadRange(1, 32); void RegexPrecompilationBench(bool enabled, benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(absl::StrCat(LabelForParam(param), "_", + enabled ? "enabled" : "disabled")); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, parser::Parse(R"cel( input_str.matches(r'192\.168\.' + '[0-9]{1,3}' + r'\.' 
+ '[0-9]{1,3}') || @@ -325,7 +361,9 @@ void BM_RegexPrecompilationDisabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationDisabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_RegexPrecompilationEnabled(benchmark::State& state) { RegexPrecompilationBench(true, state); @@ -333,10 +371,13 @@ void BM_RegexPrecompilationEnabled(benchmark::State& state) { BENCHMARK(BM_RegexPrecompilationEnabled) ->Arg(BenchmarkParam::kDefault) - ->Arg(BenchmarkParam::kFoldConstants); + ->Arg(BenchmarkParam::kFoldConstants) + ->Arg(BenchmarkParam::kRecursivePlanning) + ->Arg(BenchmarkParam::kRecursivePlanningWithConstantFolding); void BM_StringConcat(benchmark::State& state) { auto param = static_cast(state.range(0)); + state.SetLabel(LabelForParam(param)); auto size = state.range(1); std::string source = "'1234567890' + '1234567890'"; @@ -377,7 +418,17 @@ BENCHMARK(BM_StringConcat) ->Args({BenchmarkParam::kFoldConstants, 4}) ->Args({BenchmarkParam::kFoldConstants, 8}) ->Args({BenchmarkParam::kFoldConstants, 16}) - ->Args({BenchmarkParam::kFoldConstants, 32}); + ->Args({BenchmarkParam::kFoldConstants, 32}) + ->Args({BenchmarkParam::kRecursivePlanning, 2}) + ->Args({BenchmarkParam::kRecursivePlanning, 4}) + ->Args({BenchmarkParam::kRecursivePlanning, 8}) + ->Args({BenchmarkParam::kRecursivePlanning, 16}) + ->Args({BenchmarkParam::kRecursivePlanning, 32}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 2}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 4}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 8}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 16}) + ->Args({BenchmarkParam::kRecursivePlanningWithConstantFolding, 32}); void BM_StringConcat32Concurrent(benchmark::State& state) { std::string source = "'1234567890' + 
'1234567890'"; diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 1e18fef95..7a61208a0 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -139,17 +139,23 @@ struct RuntimeOptions { // removed in a later update. bool enable_lazy_bind_initialization = true; - // Maximum recursion depth for evaluable programs. + // Enable recursive planning with a maximum recursion depth for evaluable + // programs. // - // This is proportional to the maximum number of recursive Evaluate calls that - // a single expression program might require while evaluating. This is - // coarse -- the actual C++ stack requirements will vary depending on the + // This limit is proportional to the maximum number of recursive Evaluate + // calls that a single expression program might require while evaluating. This + // is coarse -- the actual C++ stack requirements will vary depending on the // expression. // // This does not account for re-entrant evaluation in a client's extension - // function. + // function (i.e. a CEL function that calls Evaluate on another CEL program). + // + // If the limit is exceeded, the planner will return an error instead of + // planning the program. // // -1 means unbounded. + // 0 means disabled (using a heap-based stack machine instead), which is the + // default. int max_recursion_depth = 0; // Enable tracing support for recursively planned programs.
diff --git a/runtime/standard_runtime_builder_factory_test.cc b/runtime/standard_runtime_builder_factory_test.cc index b73085f3c..029897233 100644 --- a/runtime/standard_runtime_builder_factory_test.cc +++ b/runtime/standard_runtime_builder_factory_test.cc @@ -52,24 +52,14 @@ using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::extensions::ProtobufRuntimeAdapter; using ::cel::test::BoolValueIs; +using ::cel::test::IntValueIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::testing::ElementsAre; +using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::Truly; -struct EvaluateResultTestCase { - std::string name; - std::string expression; - bool expected_result; - std::function activation_builder; - - template - friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { - sink.Append(tc.name); - } -}; - const cel::MacroRegistry& GetMacros() { static absl::NoDestructor macros([]() { MacroRegistry registry; @@ -88,6 +78,84 @@ absl::StatusOr ParseWithTestMacros(absl::string_view expression) { return Parse(**src, GetMacros()); } +TEST(StandardRuntimeTest, RecursionLimitExceeded) { + RuntimeOptions opts; + opts.max_recursion_depth = 1; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 1 exceeded"))); +} + +TEST(StandardRuntimeTest, RecursionUnderLimit) { + RuntimeOptions opts; + opts.max_recursion_depth = 2; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + 
ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros("1 + 2")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + + // Whether the implementation is recursive shouldn't affect observable + // behavior, but it does have performance implications (it will skip + // allocating a value stack). + EXPECT_TRUE(runtime_internal::TestOnly_IsRecursiveImpl(program.get())); + + google::protobuf::Arena arena; + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_THAT(result, IntValueIs(3)); +} + +TEST(StandardRuntimeTest, RecursionLimitTracksLazyExpressions) { + RuntimeOptions opts; + opts.max_recursion_depth = 8; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + google::protobuf::DescriptorPool::generated_pool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, ParseWithTestMacros(R"cel( + cel.bind(a, 4 + (3 + (2 + 1)), + cel.bind(b, 7 + (6 + (5 + a)), + 9 + (8 + b) + ) + ))cel")); + + EXPECT_THAT(ProtobufRuntimeAdapter::CreateProgram(*runtime, expr), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Maximum recursion depth of 8 exceeded"))); +} + +struct EvaluateResultTestCase { + std::string name; + std::string expression; + bool expected_result; + std::function activation_builder; + + template + friend void AbslStringify(S& sink, const EvaluateResultTestCase& tc) { + sink.Append(tc.name); + } +}; + class StandardRuntimeTest : public TestWithParam { public: const EvaluateResultTestCase& GetTestCase() { return GetParam(); } From 993bc2d65c215070413f564605d7b29b4cfb51ec Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 23 Mar 2026 15:54:42 -0700 Subject: [PATCH 08/88] Limit precision to <1000 in (string).format in the strings extension. 
PiperOrigin-RevId: 888321829 --- extensions/formatting.cc | 5 ++++ extensions/formatting_test.cc | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 970cc6388..6e58a7b86 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -54,6 +54,7 @@ namespace { static constexpr int32_t kNanosPerMillisecond = 1000000; static constexpr int32_t kNanosPerMicrosecond = 1000; +static constexpr int kMaxPrecision = 1000; absl::StatusOr FormatString( const Value& value, @@ -79,6 +80,10 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } + if (precision > kMaxPrecision) { + return absl::InvalidArgumentError( + absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision)); + } return std::pair{i, precision}; } diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 433e4ae24..824f14e45 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -59,6 +59,49 @@ using ::testing::HasSubstr; using ::testing::TestWithParam; using ::testing::ValuesIn; +using StringFormatLimitsTest = TestWithParam; + +// Check that formatted floating points are reversible. 
+TEST_P(StringFormatLimitsTest, FormatLimits) { + google::protobuf::Arena arena; + const RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + RegisterStringFormattingFunctions(builder.function_registry(), options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, + Parse(GetParam(), "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + static_assert(std::numeric_limits::min_exponent == -1021); + for (double x : { + 0x1p-1021, + 0x3p-1021, + std::numeric_limits::epsilon() * 0x1p-3, + std::numeric_limits::epsilon() * 0x7p-3, + 1.1 / 7.0 * 1e-101, + 1.2 / 7.0 * 1e-101, + }) { + activation.InsertOrAssignValue("x", DoubleValue(x)); + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_TRUE(value.GetBool().NativeValue()); + } +} + +INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, + ValuesIn({ + "double('%.326f'.format([x])) == x", + "double('%.17e'.format([x])) == x", + })); + struct FormattingTestCase { std::string name; std::string format; @@ -207,6 +250,12 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "'hello'", .error = "unable to find end of precision specifier", }, + { + .name = "InvalidPrecisionOutOfRange", + .format = "%.1001f", + .format_args = "1.2345", + .error = "precision specifier exceeds maximum of 1000", + }, { .name = "DecimalFormatingClause", .format = "int %d, uint %d", From 614b79f15c94beca72d98fc9852fbf8b3baa3e39 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 23 Mar 2026 15:59:33 -0700 Subject: [PATCH 09/88] Introduce EnvRuntime and config-driven standard extension functions PiperOrigin-RevId: 888323645 --- extensions/BUILD | 1 + extensions/math_ext.cc | 29
+++++++++++++++++++---------- extensions/math_ext.h | 6 ++++-- extensions/strings.cc | 37 +++++++++++++++++++++++++------------ extensions/strings.h | 5 +++-- 5 files changed, 52 insertions(+), 26 deletions(-) diff --git a/extensions/BUILD b/extensions/BUILD index 1e6e9204a..fe97af46a 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -75,6 +75,7 @@ cc_library( srcs = ["math_ext.cc"], hdrs = ["math_ext.h"], deps = [ + ":math_ext_decls", "//common:casting", "//common:value", "//eval/public:cel_function_registry", diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 7b3655de3..4d133d90c 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -308,7 +308,8 @@ Value BitShiftRightUint(uint64_t lhs, int64_t rhs) { } // namespace absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( kMathMin, Identity, registry))); @@ -360,6 +361,9 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, UnaryFunctionAdapter, ListValue>::RegisterGlobalOverload(kMathMax, MaxList, registry))); + if (version == 0) { + return absl::OkStatus(); + } CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( @@ -370,15 +374,6 @@ absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.round", RoundDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtDouble, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtInt, registry))); - CEL_RETURN_IF_ERROR( - (UnaryFunctionAdapter::RegisterGlobalOverload( - "math.sqrt", SqrtUint, registry))); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "math.trunc", TruncDouble, registry))); @@ -453,6 +448,20 @@ absl::Status 
RegisterMathExtensionFunctions(FunctionRegistry& registry, (BinaryFunctionAdapter::RegisterGlobalOverload( "math.bitShiftRight", BitShiftRightUint, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtDouble, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtInt, registry))); + CEL_RETURN_IF_ERROR( + (UnaryFunctionAdapter::RegisterGlobalOverload( + "math.sqrt", SqrtUint, registry))); + return absl::OkStatus(); } diff --git a/extensions/math_ext.h b/extensions/math_ext.h index 63d9e964b..fe000e476 100644 --- a/extensions/math_ext.h +++ b/extensions/math_ext.h @@ -18,6 +18,7 @@ #include "absl/status/status.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "extensions/math_ext_decls.h" #include "runtime/function_registry.h" #include "runtime/runtime_options.h" @@ -25,8 +26,9 @@ namespace cel::extensions { // Register extension functions for supporting mathematical operations above // and beyond the set defined in the CEL standard environment. 
-absl::Status RegisterMathExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterMathExtensionFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kMathExtensionLatestVersion); absl::Status RegisterMathExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, diff --git a/extensions/strings.cc b/extensions/strings.cc index 652c72572..ed6f27319 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -306,17 +306,8 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - CEL_RETURN_IF_ERROR(registry.Register( - UnaryFunctionAdapter, ListValue>::CreateDescriptor( - "join", /*receiver_style=*/true), - UnaryFunctionAdapter, ListValue>::WrapFunction( - Join1))); - CEL_RETURN_IF_ERROR(registry.Register( - BinaryFunctionAdapter, ListValue, StringValue>:: - CreateDescriptor("join", /*receiver_style=*/true), - BinaryFunctionAdapter, ListValue, - StringValue>::WrapFunction(Join2))); + const RuntimeOptions& options, + int version) { CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -350,7 +341,6 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, int64_t>::CreateDescriptor("replace", /*receiver_style=*/true), QuaternaryFunctionAdapter, StringValue, StringValue, StringValue, int64_t>::WrapFunction(Replace2))); - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter::RegisterMemberOverload("charAt", &CharAt, @@ -388,9 +378,32 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "trim", &Trim, registry))); + if (version == 0) { + return absl::OkStatus(); + } + + 
CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); + if (version == 1) { + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(registry.Register( + UnaryFunctionAdapter, ListValue>::CreateDescriptor( + "join", /*receiver_style=*/true), + UnaryFunctionAdapter, ListValue>::WrapFunction( + Join1))); + CEL_RETURN_IF_ERROR(registry.Register( + BinaryFunctionAdapter, ListValue, StringValue>:: + CreateDescriptor("join", /*receiver_style=*/true), + BinaryFunctionAdapter, ListValue, + StringValue>::WrapFunction(Join2))); + if (version == 2) { + return absl::OkStatus(); + } + CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterMemberOverload( "reverse", &Reverse, registry))); diff --git a/extensions/strings.h b/extensions/strings.h index 5dab33c5d..3cbc9f19f 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -28,8 +28,9 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; // Register extension functions for strings. 
-absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + int version = kStringsExtensionLatestVersion); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, From 1792682e83b401dced56213229cd4f4df68afdbd Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 24 Mar 2026 10:26:16 -0700 Subject: [PATCH 10/88] Publish cel/cpp/env to GitHub PiperOrigin-RevId: 888737100 --- MODULE.bazel | 5 + env/BUILD | 290 ++++ env/config.cc | 190 +++ env/config.h | 160 +++ env/config_test.cc | 222 ++++ env/env.cc | 318 +++++ env/env.h | 71 + env/env_runtime.cc | 77 ++ env/env_runtime.h | 73 + env/env_runtime_test.cc | 160 +++ env/env_std_extensions.cc | 76 ++ env/env_std_extensions.h | 42 + env/env_std_extensions_test.cc | 116 ++ env/env_test.cc | 713 ++++++++++ env/env_yaml.cc | 1026 ++++++++++++++ env/env_yaml.h | 39 + env/env_yaml_test.cc | 1467 +++++++++++++++++++++ env/internal/BUILD | 87 ++ env/internal/ext_registry.cc | 63 + env/internal/ext_registry.h | 74 ++ env/internal/ext_registry_test.cc | 73 + env/internal/runtime_ext_registry.cc | 64 + env/internal/runtime_ext_registry.h | 84 ++ env/internal/runtime_ext_registry_test.cc | 126 ++ env/runtime_std_extensions.cc | 130 ++ env/runtime_std_extensions.h | 46 + env/runtime_std_extensions_test.cc | 229 ++++ 27 files changed, 6021 insertions(+) create mode 100644 env/BUILD create mode 100644 env/config.cc create mode 100644 env/config.h create mode 100644 env/config_test.cc create mode 100644 env/env.cc create mode 100644 env/env.h create mode 100644 env/env_runtime.cc create mode 100644 env/env_runtime.h create mode 100644 env/env_runtime_test.cc create mode 100644 env/env_std_extensions.cc create mode 100644 env/env_std_extensions.h create mode 100644 env/env_std_extensions_test.cc create mode 100644 env/env_test.cc create mode 100644 
env/env_yaml.cc create mode 100644 env/env_yaml.h create mode 100644 env/env_yaml_test.cc create mode 100644 env/internal/BUILD create mode 100644 env/internal/ext_registry.cc create mode 100644 env/internal/ext_registry.h create mode 100644 env/internal/ext_registry_test.cc create mode 100644 env/internal/runtime_ext_registry.cc create mode 100644 env/internal/runtime_ext_registry.h create mode 100644 env/internal/runtime_ext_registry_test.cc create mode 100644 env/runtime_std_extensions.cc create mode 100644 env/runtime_std_extensions.h create mode 100644 env/runtime_std_extensions_test.cc diff --git a/MODULE.bazel b/MODULE.bazel index fbe9b41fc..02404b645 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -100,3 +100,8 @@ http_jar( sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", urls = ["https://www.antlr.org/download/antlr-" + ANTLR4_VERSION + "-complete.jar"], ) + +bazel_dep( + name = "yaml-cpp", + version = "0.9.0", +) diff --git a/env/BUILD b/env/BUILD new file mode 100644 index 000000000..f5ce35557 --- /dev/null +++ b/env/BUILD @@ -0,0 +1,290 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + "//common:constant", + "//internal:status_macros", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "env", + srcs = ["env.cc"], + hdrs = ["env.h"], + deps = [ + ":config", + "//checker:type_checker_builder", + "//common:constant", + "//common:decl", + "//common:type", + "//common:type_kind", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//env/internal:ext_registry", + "//internal:status_macros", + "//parser:macro", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_runtime", + srcs = ["env_runtime.cc"], + hdrs = ["env_runtime.h"], + deps = [ + ":config", + "//env/internal:runtime_ext_registry", + "//internal:status_macros", + "//runtime", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "//runtime:standard_functions", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "env_std_extensions", + srcs = ["env_std_extensions.cc"], + hdrs = ["env_std_extensions.h"], + deps = [ + ":env", + 
"//checker:optional", + "//compiler:optional", + "//extensions:bindings_ext", + "//extensions:comprehensions_v2", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:proto_ext", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + ], +) + +cc_library( + name = "env_yaml", + srcs = ["env_yaml.cc"], + hdrs = ["env_yaml.h"], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + ":config", + "//common:constant", + "//internal:status_macros", + "//internal:strings", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@yaml-cpp", + ], +) + +cc_library( + name = "runtime_std_extensions", + srcs = ["runtime_std_extensions.cc"], + hdrs = ["runtime_std_extensions.h"], + deps = [ + ":env_runtime", + "//checker:optional", + "//env/internal:runtime_ext_registry", + "//extensions:encoders", + "//extensions:lists_functions", + "//extensions:math_ext", + "//extensions:math_ext_decls", + "//extensions:regex_ext", + "//extensions:sets_functions", + "//extensions:strings", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "config_test", + srcs = ["config_test.cc"], + deps = [ + ":config", + "//common:constant", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "env_test", + srcs = ["env_test.cc"], + deps = [ + ":config", + ":env", + "//checker:type_check_issue", + 
"//checker:type_checker_builder", + "//checker:validation_result", + "//common:ast", + "//common:ast_proto", + "//common:constant", + "//common:decl", + "//common:expr", + "//common:type", + "//common:value", + "//compiler", + "//internal:proto_matchers", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:reference_resolver", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_runtime_test", + srcs = ["env_runtime_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":env_yaml", + ":runtime_std_extensions", + "//checker:validation_result", + "//common:ast", + "//common:source", + "//common:value", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "env_std_extensions_test", + srcs = ["env_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_std_extensions", + "//compiler", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "env_yaml_test", + srcs = ["env_yaml_test.cc"], + deps = [ + ":config", + ":env_yaml", + "//common:constant", + "//internal:status_macros", + "//internal:testing", + 
"@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_test( + name = "runtime_std_extensions_test", + srcs = ["runtime_std_extensions_test.cc"], + deps = [ + ":config", + ":env", + ":env_runtime", + ":env_std_extensions", + ":runtime_std_extensions", + "//checker:optional", + "//checker:validation_result", + "//common:ast", + "//common:value", + "//compiler", + "//extensions:lists_functions", + "//extensions:math_ext_decls", + "//extensions:strings", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/config.cc b/env/config.cc new file mode 100644 index 000000000..ccb4de34c --- /dev/null +++ b/env/config.cc @@ -0,0 +1,190 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/config.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/overload.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "internal/status_macros.h" + +namespace cel { + +namespace { + +const char* ConstantKindToTypeName(const ConstantKind& kind) { + return std::visit(absl::Overload{ + [](const std::monostate& arg) { return "dyn"; }, + [](const std::nullptr_t& arg) { return "null"; }, + [](bool arg) { return "bool"; }, + [](int64_t arg) { return "int"; }, + [](uint64_t arg) { return "uint"; }, + [](double arg) { return "double"; }, + [](const BytesConstant& arg) { return "bytes"; }, + [](const StringConstant& arg) { return "string"; }, + [](absl::Duration arg) { return "duration"; }, + [](absl::Time arg) { return "timestamp"; }, + }, + kind); +} +} // namespace + +absl::Status Config::AddExtensionConfig(std::string name, int version) { + for (const ExtensionConfig& extension_config : extension_configs_) { + if (extension_config.name == name) { + if (extension_config.version == version) { + return absl::OkStatus(); + } + return absl::AlreadyExistsError(absl::StrCat( + "Extension '", name, "' version ", extension_config.version, + " is already included. 
Cannot also include version ", version)); + } + } + extension_configs_.push_back( + ExtensionConfig{.name = std::move(name), .version = version}); + return absl::OkStatus(); +} + +absl::Status Config::SetStandardLibraryConfig( + const Config::StandardLibraryConfig& standard_library_config) { + if (!standard_library_config.included_macros.empty() && + !standard_library_config.excluded_macros.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded macros."); + } + + if (!standard_library_config.included_functions.empty() && + !standard_library_config.excluded_functions.empty()) { + return absl::InvalidArgumentError( + "Cannot set both included and excluded functions."); + } + + absl::flat_hash_set included_function_names; + for (const auto& function : standard_library_config.included_functions) { + if (function.second.empty()) { + included_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.included_functions) { + if (included_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot include function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + absl::flat_hash_set excluded_function_names; + for (const auto& function : standard_library_config.excluded_functions) { + if (function.second.empty()) { + excluded_function_names.insert(function.first); + } + } + for (const auto& function : standard_library_config.excluded_functions) { + if (excluded_function_names.contains(function.first) && + !function.second.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot exclude function '", function.first, + "' and also its specific overload '", function.second, "'")); + } + } + + standard_library_config_ = standard_library_config; + return absl::OkStatus(); +} + +absl::Status Config::AddVariableConfig(const VariableConfig& variable_config) { + for (const 
VariableConfig& existing_variable_config : variable_configs_) { + if (existing_variable_config.name == variable_config.name) { + return absl::AlreadyExistsError(absl::StrCat( + "Variable '", variable_config.name, "' is already included.")); + } + } + if (variable_config.value.has_value()) { + absl::string_view constant_type_name = + ConstantKindToTypeName(variable_config.value.kind()); + if (constant_type_name != variable_config.type_info.name) { + return absl::InvalidArgumentError( + absl::StrCat("Variable '", variable_config.name, "' has type ", + variable_config.type_info.name, + " but is assigned a constant value of type ", + constant_type_name, ".")); + } + } + variable_configs_.push_back(variable_config); + return absl::OkStatus(); +} + +absl::Status Config::ValidateFunctionConfig( + const FunctionConfig& function_config) { + for (const auto& overload : function_config.overload_configs) { + if (overload.is_member_function && overload.parameters.empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Function '", function_config.name, "' overload '", + overload.overload_id, + "' is marked as a member function but has no parameters. 
Member " + "functions must have at least one parameter (target).")); + } + } + return absl::OkStatus(); +} + +absl::Status Config::AddFunctionConfig(const FunctionConfig& function_config) { + CEL_RETURN_IF_ERROR(ValidateFunctionConfig(function_config)); + function_configs_.push_back(function_config); + return absl::OkStatus(); +} + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config) { + os << "StandardLibraryConfig("; + if (!config.included_macros.empty()) { + os << "\n included_macros=" << absl::StrJoin(config.included_macros, ", "); + } + if (!config.excluded_macros.empty()) { + os << "\n excluded_macros=" << absl::StrJoin(config.excluded_macros, ", "); + } + if (!config.included_functions.empty()) { + os << "\n included_functions=" + << absl::StrJoin(config.included_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + if (!config.excluded_functions.empty()) { + os << "\n excluded_functions=" + << absl::StrJoin(config.excluded_functions, ", ", + [](std::string* out, + const std::pair& p) { + absl::StrAppend(out, p.first, ":", p.second); + }); + } + os << "\n)"; + return os; +} + +} // namespace cel diff --git a/env/config.h b/env/config.h new file mode 100644 index 000000000..10b23d030 --- /dev/null +++ b/env/config.h @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ +#define THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "common/constant.h" + +namespace cel { + +class Config { + public: + void SetName(std::string name) { name_ = std::move(name); } + std::string GetName() const { return name_; } + + struct ContainerConfig { + std::string name; + // TODO(uncreated-issue/87): add support for aliases and abbreviations. + + bool IsEmpty() const { return name.empty(); } + }; + + void SetContainerConfig(ContainerConfig container_config) { + container_config_ = std::move(container_config); + } + + const ContainerConfig& GetContainerConfig() const { + return container_config_; + } + + struct ExtensionConfig { + static constexpr int kLatest = std::numeric_limits::max(); + + std::string name; + int version = kLatest; + }; + + absl::Status AddExtensionConfig(std::string name, + int version = ExtensionConfig::kLatest); + + const std::vector& GetExtensionConfigs() const { + return extension_configs_; + } + + struct StandardLibraryConfig { + // Exclude the entire standard library. + bool disable = false; + + // Exclude all standard library macros. + bool disable_macros = false; + + // Either included or excluded macros can be set, not both. If neither are + // set, all standard library macros are included. + absl::flat_hash_set included_macros; + absl::flat_hash_set excluded_macros; + + // Sets of pairs of function name and overload id to include or exclude. + // Either included or excluded functions can be set, not both. If neither + // are set, all standard library functions are included. + // If an overload is specified, only that overload is included or excluded. + // If no overload is specified (empty second element of pair), all overloads + // are included or excluded. 
+ absl::flat_hash_set> included_functions; + absl::flat_hash_set> excluded_functions; + + bool IsEmpty() const { + return !disable && !disable_macros && included_macros.empty() && + excluded_macros.empty() && included_functions.empty() && + excluded_functions.empty(); + } + }; + + absl::Status SetStandardLibraryConfig( + const StandardLibraryConfig& standard_library_config); + + const StandardLibraryConfig& GetStandardLibraryConfig() const { + return standard_library_config_; + } + + struct TypeInfo { + std::string name; + std::vector params; + bool is_type_param = false; + }; + + struct VariableConfig { + std::string name; + std::string description; + TypeInfo type_info; + Constant value; + }; + + // Adds a variable config to the environment. The variable name and type + // are used by the CEL type checker to validate expressions. The variable + // value is used as an input value at runtime. + // + // Returns an error if a variable with the same name already exists, or if the + // type of the constant value does not match the specified type. 
+ absl::Status AddVariableConfig(const VariableConfig& variable_config); + + const std::vector& GetVariableConfigs() const { + return variable_configs_; + } + + struct FunctionOverloadConfig { + std::string overload_id; + std::vector examples; + bool is_member_function = false; + std::vector parameters; + TypeInfo return_type; + }; + + struct FunctionConfig { + std::string name; + std::string description; + std::vector overload_configs; + }; + + absl::Status AddFunctionConfig(const FunctionConfig& function_config); + + const std::vector& GetFunctionConfigs() const { + return function_configs_; + } + + private: + std::string name_; + ContainerConfig container_config_; + std::vector extension_configs_; + StandardLibraryConfig standard_library_config_; + std::vector variable_configs_; + std::vector function_configs_; + + absl::Status ValidateFunctionConfig(const FunctionConfig& function_config); +}; + +std::ostream& operator<<(std::ostream& os, + const Config::StandardLibraryConfig& config); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_CONFIG_H_ diff --git a/env/config_test.cc b/env/config_test.cc new file mode 100644 index 000000000..df0d6f875 --- /dev/null +++ b/env/config_test.cc @@ -0,0 +1,222 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/config.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/constant.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Field; +using ::testing::HasSubstr; +using ::testing::UnorderedElementsAre; + +TEST(EnvConfigTest, ExtensionConfigs) { + Config config; + ASSERT_THAT( + config.AddExtensionConfig("math", Config::ExtensionConfig::kLatest), + IsOk()); + ASSERT_THAT(config.AddExtensionConfig("optional", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("strings"), IsOk()); + + EXPECT_THAT(config.GetExtensionConfigs(), + UnorderedElementsAre( + AllOf(Field(&Config::ExtensionConfig::name, "math"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)), + AllOf(Field(&Config::ExtensionConfig::name, "optional"), + Field(&Config::ExtensionConfig::version, 2)), + AllOf(Field(&Config::ExtensionConfig::name, "strings"), + Field(&Config::ExtensionConfig::version, + Config::ExtensionConfig::kLatest)))); +} + +TEST(EnvConfigTest, ExtensionConfigConflict) { + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + ASSERT_THAT(config.AddExtensionConfig("math", 3), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +struct StandardLibraryConfigTestCase { + Config::StandardLibraryConfig standard_library_config; + std::string expected_error; // Empty if no error is expected. 
+}; + +class StandardLibraryConfigTest + : public testing::TestWithParam {}; + +TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) { + const StandardLibraryConfigTestCase& param = GetParam(); + + Config config; + absl::Status status = + config.SetStandardLibraryConfig(param.standard_library_config); + if (param.expected_error.empty()) { + EXPECT_THAT(status, IsOk()); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } +} + +INSTANTIATE_TEST_SUITE_P( + StandardLibraryConfigTest, StandardLibraryConfigTest, + ::testing::Values( + StandardLibraryConfigTestCase{ + .standard_library_config = {}, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_macros = {"all", "exists"}, + .excluded_macros = {"map", "filter"}, + }, + .expected_error = "Cannot set both included and excluded macros.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, {"_+_", "add_list"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add_list'", + })); + +TEST(VariableConfigTest, VariableConfig) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = + { + .name = "mytype", + .params = {{.name = "int"}, {.name = "A", .is_type_param = true}}, + }, + }; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + + ASSERT_EQ(config.GetVariableConfigs().size(), 1); + 
const auto& added_config = config.GetVariableConfigs()[0]; + EXPECT_EQ(added_config.type_info.name, "mytype"); + ASSERT_THAT(added_config.type_info.params.size(), 2); + EXPECT_EQ(added_config.type_info.params[0].name, "int"); + EXPECT_FALSE(added_config.type_info.params[0].is_type_param); + EXPECT_EQ(added_config.type_info.params[1].name, "A"); + EXPECT_TRUE(added_config.type_info.params[1].is_type_param); +} + +TEST(VariableConfigTest, VariableConfigConflict) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), IsOk()); + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(VariableConfigTest, VariableConfigValueTypeMismatch) { + Config config; + Config::VariableConfig variable_config{ + .name = "test", + .type_info = {.name = "int"}, + .value = Constant(StringConstant("hello")), + }; + EXPECT_THAT(config.AddVariableConfig(variable_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Variable 'test' has type int but is assigned " + "a constant value of type string."))); +} + +TEST(FunctionConfigTest, FunctionConfig) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.description = "Ultimate test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_with_pill", + .examples = {"oracle.isTheOne('Neo', RED)"}, + .is_member_function = true, + .parameters = {{.name = "string"}, {.name = "Choice"}}, + .return_type = {.name = "bool"}, + }); + ASSERT_THAT(config.AddFunctionConfig(function_config), IsOk()); + ASSERT_EQ(config.GetFunctionConfigs().size(), 1); + const auto& added_config = config.GetFunctionConfigs()[0]; + EXPECT_EQ(added_config.name, "test"); + EXPECT_EQ(added_config.description, "Ultimate test"); + EXPECT_EQ(added_config.overload_configs.size(), 1); + + const 
auto& overload_config = added_config.overload_configs[0]; + EXPECT_EQ(overload_config.overload_id, "test_with_pill"); + EXPECT_THAT(overload_config.examples, + ElementsAre("oracle.isTheOne('Neo', RED)")); + EXPECT_TRUE(overload_config.is_member_function); + EXPECT_THAT( + overload_config.parameters, + ElementsAre(AllOf(Field(&Config::TypeInfo::name, "string"), + Field(&Config::TypeInfo::is_type_param, false)), + AllOf(Field(&Config::TypeInfo::name, "Choice"), + Field(&Config::TypeInfo::is_type_param, false)))); + EXPECT_THAT(overload_config.return_type, + Field(&Config::TypeInfo::name, "bool")); +} + +TEST(FunctionConfigTest, FunctionConfigInvalidMember) { + Config config; + Config::FunctionConfig function_config; + function_config.name = "test"; + function_config.overload_configs.push_back(Config::FunctionOverloadConfig{ + .overload_id = "test_member_no_params", + .is_member_function = true, + .parameters = {}, + }); + EXPECT_THAT(config.AddFunctionConfig(function_config), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("is marked as a member function but has no " + "parameters"))); +} + +} // namespace +} // namespace cel diff --git a/env/env.cc b/env/env.cc new file mode 100644 index 000000000..2c2555f14 --- /dev/null +++ b/env/env.cc @@ -0,0 +1,318 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/env.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/type_checker_builder.h" +#include "common/constant.h" +#include "common/decl.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, + absl::string_view macro) { + if (config.disable_macros) { + return false; + } + if (config.excluded_macros.contains(macro)) { + return false; + } + if (!config.included_macros.empty() && + !config.included_macros.contains(macro)) { + return false; + } + return true; +} + +bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, + absl::string_view function, + absl::string_view overload_id) { + if (config.excluded_functions.contains( + std::make_pair(std::string(function), std::string(overload_id))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + if (!config.included_functions.empty() && + !config.included_functions.contains( + std::make_pair(std::string(function), "")) && + !config.included_functions.contains( + std::make_pair(std::string(function), std::string(overload_id)))) { + return false; + } + return true; +} + +absl::StatusOr MakeStdlibSubset( + const Config::StandardLibraryConfig& standard_library_config) { + CompilerLibrarySubset subset; + subset.library_id = "stdlib"; + // Capturing by reference is safe. 
The returned CompilerLibrarySubset's + // callbacks are only used during CompilerBuilder::Build() to configure + // contributed functions and macros. They are not retained by the constructed + // Compiler instance. The referenced config outlives the Build() call. + subset.should_include_macro = [&standard_library_config](const Macro& macro) { + return ShouldIncludeMacro(standard_library_config, macro.function()); + }; + subset.should_include_overload = [&standard_library_config]( + absl::string_view function, + absl::string_view overload_id) { + return ShouldIncludeFunction(standard_library_config, function, + overload_id); + }; + return subset; +} + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, 
google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return MessageType(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, arena, descriptor_pool)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], arena, descriptor_pool)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + arena, descriptor_pool)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], arena, descriptor_pool)); + } + return MapType(arena, 
key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], arena, + descriptor_pool)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} + +absl::StatusOr FunctionConfigToFunctionDecl( + const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool* descriptor_pool) { + FunctionDecl function_decl; + function_decl.set_name(function_config.name); + for (const Config::FunctionOverloadConfig& overload_config : + function_config.overload_configs) { + OverloadDecl overload_decl; + overload_decl.set_id(overload_config.overload_id); + overload_decl.set_member(overload_config.is_member_function); + for (const Config::TypeInfo& parameter : overload_config.parameters) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(parameter, arena, descriptor_pool)); + overload_decl.mutable_args().push_back(parameter_type); + } + CEL_ASSIGN_OR_RETURN( + Type return_type, + TypeInfoToType(overload_config.return_type, arena, descriptor_pool)); + overload_decl.set_result(return_type); + CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); + } + return function_decl; +} + +} // namespace + +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN( + std::unique_ptr compiler_builder, + cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); + cel::TypeCheckerBuilder& checker_builder = + 
compiler_builder->GetCheckerBuilder(); + + checker_builder.set_container(config_.GetContainerConfig().name); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrary(StandardCompilerLibrary())); + CEL_ASSIGN_OR_RETURN(CompilerLibrarySubset standard_library_subset, + MakeStdlibSubset(config_.GetStandardLibraryConfig())); + CEL_RETURN_IF_ERROR( + compiler_builder->AddLibrarySubset(std::move(standard_library_subset))); + } + for (const Config::ExtensionConfig& extension_config : + config_.GetExtensionConfigs()) { + CEL_ASSIGN_OR_RETURN(CompilerLibrary library, + extension_registry_.GetCompilerLibrary( + extension_config.name, extension_config.version)); + CEL_RETURN_IF_ERROR(compiler_builder->AddLibrary(std::move(library))); + } + + google::protobuf::Arena* arena = checker_builder.arena(); + for (const Config::VariableConfig& variable_config : + config_.GetVariableConfigs()) { + VariableDecl variable_decl; + variable_decl.set_name(variable_config.name); + CEL_ASSIGN_OR_RETURN(Type type, + TypeInfoToType(variable_config.type_info, arena, + descriptor_pool_.get())); + variable_decl.set_type(type); + if (variable_config.value.has_value()) { + variable_decl.set_value(variable_config.value); + } + CEL_RETURN_IF_ERROR(checker_builder.AddVariable(variable_decl)); + } + + for (const Config::FunctionConfig& function_config : + config_.GetFunctionConfigs()) { + CEL_ASSIGN_OR_RETURN(FunctionDecl function_decl, + FunctionConfigToFunctionDecl(function_config, arena, + descriptor_pool_.get())); + CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); + } + + return compiler_builder->Build(); +} + +} // namespace cel diff --git a/env/env.h b/env/env.h new file mode 100644 index 000000000..f46e5947c --- /dev/null +++ b/env/env.h @@ -0,0 +1,71 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_H_
+#define THIRD_PARTY_CEL_CPP_ENV_ENV_H_
+
+// NOTE(review): the bare `#include` directives and the template names without
+// arguments in this header look like extraction damage (text between angle
+// brackets lost) - confirm against the upstream file before building.
+#include
+
+#include
+#include
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "compiler/compiler.h"
+#include "env/config.h"
+#include "env/internal/ext_registry.h"
+#include "google/protobuf/descriptor.h"
+
+namespace cel {
+
+// Env class establishes the environment for compiling CEL expressions.
+//
+// It is used to configure compiler options, extension functions, and other
+// customizable CEL features.
+class Env {
+ public:
+  // Registers a `CompilerLibrary` with the environment. Note that the library
+  // does not automatically get added to a `Compiler`. `NewCompiler` relies
+  // on `Config` to determine which libraries to load.
+  //
+  // The arguments are forwarded unchanged to the extension registry;
+  // `library_factory` is moved into it.
+  void RegisterCompilerLibrary(
+      absl::string_view name, absl::string_view alias, int version,
+      absl::AnyInvocable library_factory) {
+    extension_registry_.RegisterCompilerLibrary(name, alias, version,
+                                                std::move(library_factory));
+  }
+
+  // Sets the descriptor pool used by `NewCompiler`, both for the compiler
+  // builder and for resolving message type names in the config.
+  void SetDescriptorPool(
+      std::shared_ptr descriptor_pool) {
+    descriptor_pool_ = std::move(descriptor_pool);
+  }
+
+  // Returns the current descriptor pool, or nullptr if none has been set.
+  const google::protobuf::DescriptorPool* GetDescriptorPool() const {
+    return descriptor_pool_.get();
+  }
+
+  // Replaces the stored configuration with a copy of `config`.
+  void SetConfig(const Config& config) { config_ = config; }
+
+  // Builds a compiler from the stored config, descriptor pool and registered
+  // libraries. Defined in env.cc.
+  absl::StatusOr> NewCompiler();
+
+ private:
+  cel::env_internal::ExtensionRegistry extension_registry_;
+  std::shared_ptr descriptor_pool_;
+  CompilerOptions compiler_options_;
+  Config config_;
+};
+
+}  // namespace cel
+
+#endif  // THIRD_PARTY_CEL_CPP_ENV_ENV_H_
diff --git a/env/env_runtime.cc b/env/env_runtime.cc
new file mode 100644
index 000000000..09bbcde04
--- /dev/null
+++ b/env/env_runtime.cc
@@ -0,0 +1,77 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "env/env_runtime.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" + +namespace cel { + +absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { + const std::vector& extension_configs = + config_.GetExtensionConfigs(); + const Config::ExtensionConfig* optional_extension_config = nullptr; + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (extension_config.name == "optional") { + optional_extension_config = &extension_config; + runtime_options_.enable_qualified_type_identifiers = true; + break; + } + } + + CEL_ASSIGN_OR_RETURN( + RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool_, runtime_options_)); + + if (!config_.GetStandardLibraryConfig().disable) { + CEL_RETURN_IF_ERROR(RegisterStandardFunctions( + runtime_builder.function_registry(), runtime_options_)); + } + + // Register optional extension functions first, because other extensions + // depend on it (e.g. regex). 
+ if (optional_extension_config != nullptr) { + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, optional_extension_config->name, + optional_extension_config->version)); + } + + for (const Config::ExtensionConfig& extension_config : extension_configs) { + if (&extension_config == optional_extension_config) { + continue; + } + CEL_RETURN_IF_ERROR(extension_registry_.RegisterExtensionFunctions( + runtime_builder, runtime_options_, extension_config.name, + extension_config.version)); + } + return runtime_builder; +} + +absl::StatusOr> EnvRuntime::NewRuntime() { + CEL_ASSIGN_OR_RETURN(RuntimeBuilder runtime_builder, CreateRuntimeBuilder()); + return std::move(runtime_builder).Build(); +} + +} // namespace cel diff --git a/env/env_runtime.h b/env/env_runtime.h new file mode 100644 index 000000000..ff62ec1d4 --- /dev/null +++ b/env/env_runtime.h @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ +#define THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "env/config.h" +#include "env/internal/runtime_ext_registry.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// EnvRuntime class establishes the environment for creating CEL runtimes. 
+// +// It is used to configure runtime options, extension functions, and other +// customizable CEL runtime features. +// +// EnvRuntime is separate from Env to avoid a dependency on the compiler for +// binaries that only use the runtime. +// +// Even though EnvRuntime is separate from Env, the Config and DescriptorPool +// passed to EnvRuntime are expected to be the same as those passed to Env for +// compilation. This ensures consistency between compilation and runtime. +class EnvRuntime { + public: + void SetDescriptorPool( + std::shared_ptr descriptor_pool) { + descriptor_pool_ = std::move(descriptor_pool); + } + + void SetConfig(const Config& config) { config_ = config; } + + RuntimeOptions& mutable_runtime_options() { return runtime_options_; } + + absl::StatusOr CreateRuntimeBuilder(); + + // Shortcut for CreateRuntimeBuilder() followed by Build(). + absl::StatusOr> NewRuntime(); + + private: + cel::env_internal::RuntimeExtensionRegistry& GetRuntimeExtensionRegistry() { + return extension_registry_; + } + + friend void RegisterStandardExtensions(EnvRuntime& env_runtime); + + cel::env_internal::RuntimeExtensionRegistry extension_registry_; + std::shared_ptr descriptor_pool_; + Config config_; + RuntimeOptions runtime_options_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_ENV_RUNTIME_H_ diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc new file mode 100644 index 000000000..1c4205224 --- /dev/null +++ b/env/env_runtime_test.cc @@ -0,0 +1,160 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "env/env_runtime.h"
+
+// NOTE(review): the bare `#include` directives and the template names without
+// arguments in this file look like extraction damage (text between angle
+// brackets lost) - confirm against the upstream file before building.
+#include
+#include
+#include
+#include
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "checker/validation_result.h"
+#include "common/ast.h"
+#include "common/source.h"
+#include "common/value.h"
+#include "compiler/compiler.h"
+#include "env/config.h"
+#include "env/env.h"
+#include "env/env_std_extensions.h"
+#include "env/env_yaml.h"
+#include "env/runtime_std_extensions.h"
+#include "internal/testing.h"
+#include "internal/testing_descriptor_pool.h"
+#include "runtime/activation.h"
+#include "runtime/runtime.h"
+#include "google/protobuf/arena.h"
+
+namespace cel {
+namespace {
+
+using ::absl_testing::IsOk;
+using ::absl_testing::StatusIs;
+using ::testing::IsEmpty;
+using ::testing::ValuesIn;
+
+// One end-to-end scenario: a YAML environment config, an expression to run
+// under it, and whether program creation is expected to be rejected.
+struct TestCase {
+  std::string config_yaml;
+  std::string expr;
+  bool expected_to_fail = false;
+};
+
+class EnvRuntimeTest : public testing::TestWithParam {};
+
+// Configures an Env and an EnvRuntime from the same config and descriptor
+// pool, then either evaluates the expression and expects `true`, or expects
+// CreateProgram to fail with kInvalidArgument.
+TEST_P(EnvRuntimeTest, EndToEnd) {
+  const TestCase& param = GetParam();
+  auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool();
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.config_yaml));
+
+  Env env;
+  env.SetDescriptorPool(descriptor_pool);
+  RegisterStandardExtensions(env);
+  env.SetConfig(config);
+
+  EnvRuntime env_runtime;
+  env_runtime.SetDescriptorPool(descriptor_pool);
+  RegisterStandardExtensions(env_runtime);
+  env_runtime.SetConfig(config);
+
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
+  std::unique_ptr ast;
+  if (!param.expected_to_fail) {
+    ASSERT_OK_AND_ASSIGN(ValidationResult result,
+                         compiler->Compile(param.expr));
+    EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError();
+    ASSERT_OK_AND_ASSIGN(ast, result.ReleaseAst());
+  } else {
+    // Bypass type checking to allow compilation to succeed since we expect the
+    // runtime to fail.
+    ASSERT_OK_AND_ASSIGN(std::unique_ptr source,
+                         NewSource(param.expr, ""));
+    ASSERT_OK_AND_ASSIGN(ast, compiler->GetParser().Parse(*source));
+  }
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime,
+                       env_runtime.NewRuntime());
+
+  absl::StatusOr> program_or =
+      runtime->CreateProgram(std::move(ast));
+  if (param.expected_to_fail) {
+    EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument))
+        << " expr: " << param.expr;
+    return;
+  }
+
+  ASSERT_THAT(program_or, IsOk()) << " expr: " << param.expr;
+
+  std::unique_ptr program = *std::move(program_or);
+  ASSERT_NE(program, nullptr);
+
+  google::protobuf::Arena arena;
+  Activation activation;
+  ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
+  EXPECT_TRUE(value.GetBool()) << " expr: " << param.expr;
+}
+
+// Scenarios covering: one extension, "optional" listed after another
+// extension, a missing "optional" extension, and a disabled standard library
+// with and without an extension.
+std::vector GetEnvRuntimeTestCases() {
+  return {
+      TestCase{
+          .config_yaml = R"yaml(
+            extensions:
+              - name: "encoders"
+          )yaml",
+          .expr = "base64.encode(b'hello') == 'aGVsbG8='",
+      },
+      TestCase{
+          .config_yaml = R"yaml(
+            extensions:
+              - name: "encoders"
+              - name: "optional"
+          )yaml",
+          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
+                  "optional.of(1).hasValue()",
+      },
+      TestCase{
+          .config_yaml = R"yaml(
+            extensions:
+              - name: "encoders"
+          )yaml",
+          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
+                  "optional.of(1).hasValue()",
+          .expected_to_fail = true,
+      },
+      TestCase{
+          .config_yaml = R"yaml(
+            stdlib:
+              disable: true
+          )yaml",
+          .expr = "1 + 2 == 3",
+          .expected_to_fail = true,
+      },
+      TestCase{
+          .config_yaml = R"yaml(
+            stdlib:
+              disable: true
+            extensions:
+              - name: "encoders"
+          )yaml",
+          .expr = "base64.encode(b'hello') == 'aGVsbG8=' && "
+                  "1 + 2 == 3",
+          .expected_to_fail = true,
+      },
+  };
+}
+
+INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest,
+                         ValuesIn(GetEnvRuntimeTestCases()));
+
+}  // namespace
+}  // namespace cel
diff --git a/env/env_std_extensions.cc b/env/env_std_extensions.cc
new file mode 100644
index 000000000..f2041b979
--- /dev/null
+++
b/env/env_std_extensions.cc @@ -0,0 +1,76 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/env_std_extensions.h" + +#include "checker/optional.h" +#include "compiler/optional.h" +#include "env/env.h" +#include "extensions/bindings_ext.h" +#include "extensions/comprehensions_v2.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/proto_ext.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" + +namespace cel { + +void RegisterStandardExtensions(Env& env) { + env.RegisterCompilerLibrary("cel.lib.ext.bindings", "bindings", 0, []() { + return extensions::BindingsCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.encoders", "encoders", 0, []() { + return extensions::EncodersCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.lists", "lists", version, + [version]() { return extensions::ListsCompilerLibrary(version); }); + } + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.math", "math", version, + [version]() { return extensions::MathCompilerLibrary(version); }); + } + for (int version = 0; version <= kOptionalExtensionLatestVersion; ++version) { + 
env.RegisterCompilerLibrary("optional", "", version, [version]() { + return OptionalCompilerLibrary(version); + }); + } + env.RegisterCompilerLibrary("cel.lib.ext.protos", "protos", 0, []() { + return extensions::ProtoExtCompilerLibrary(); + }); + env.RegisterCompilerLibrary("cel.lib.ext.sets", "sets", 0, []() { + return extensions::SetsCompilerLibrary(); + }); + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + env.RegisterCompilerLibrary( + "cel.lib.ext.strings", "strings", version, + [version]() { return extensions::StringsCompilerLibrary(version); }); + } + env.RegisterCompilerLibrary( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + []() { return extensions::ComprehensionsV2CompilerLibrary(); }); + env.RegisterCompilerLibrary("cel.lib.ext.regex", "regex", 0, []() { + return extensions::RegexExtCompilerLibrary(); + }); +} + +} // namespace cel diff --git a/env/env_std_extensions.h b/env/env_std_extensions.h new file mode 100644 index 000000000..79cf37dbf --- /dev/null +++ b/env/env_std_extensions.h @@ -0,0 +1,42 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ + +#include "env/env.h" + +namespace cel { + +// Registers the standard CEL extensions with the given environment. This makes +// them available, but does not enable them. 
See Env::Config for how to enable +// extensions. +// +// Extensions are registered under the following names: +// +// - cel.lib.ext.bindings (alias: "bindings") +// - cel.lib.ext.encoders (alias: "encoders") +// - cel.lib.ext.lists (alias: "lists") +// - cel.lib.ext.math (alias: "math") +// - optional +// - cel.lib.ext.protos (alias: "protos") +// - cel.lib.ext.sets (alias: "sets") +// - cel.lib.ext.strings (alias: "strings") +// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions") +// - cel.lib.ext.regex (alias: "regex") +void RegisterStandardExtensions(Env& env); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_STD_EXTENSIONS_H_ diff --git a/env/env_std_extensions_test.cc b/env/env_std_extensions_test.cc new file mode 100644 index 000000000..7d9572cc0 --- /dev/null +++ b/env/env_std_extensions_test.cc @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "env/env_std_extensions.h"
+
+// NOTE(review): the bare `#include` directives and `TestWithParam` without a
+// template argument look like extraction damage (text between angle brackets
+// lost) - confirm against the upstream file before building.
+#include
+#include
+
+#include "absl/strings/string_view.h"
+#include "compiler/compiler.h"
+#include "env/config.h"
+#include "env/env.h"
+#include "internal/testing.h"
+#include "internal/testing_descriptor_pool.h"
+
+namespace cel {
+namespace {
+
+using ::absl_testing::IsOk;
+using ::testing::TestWithParam;
+
+// An extension name (official name or alias) and an expression that should
+// compile once only that extension is enabled.
+struct TestCase {
+  std::string extension;
+  std::string expr;
+};
+
+class EnvStdExtensions : public testing::TestWithParam {};
+
+// Registers the standard extensions, enables the single extension under test
+// through the config, and checks that the expression compiles as valid.
+TEST_P(EnvStdExtensions, RegistrationTest) {
+  const TestCase& param = GetParam();
+
+  Env env;
+  RegisterStandardExtensions(env);
+  env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool());
+
+  Config config;
+  ASSERT_THAT(config.AddExtensionConfig(param.extension), IsOk());
+  env.SetConfig(config);
+
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
+
+  ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(param.expr));
+  ASSERT_TRUE(result.IsValid()) << "Expected no issues for expr: " << param.expr
+                                << " but got: " << result.FormatError();
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RegistrationTest, EnvStdExtensions,
+    ::testing::Values(
+        TestCase{
+            .extension = "cel.lib.ext.bindings",  // official name
+            .expr = "cel.bind(t, true, t)",
+        },
+        TestCase{
+            .extension = "bindings",  // alias
+            .expr = "cel.bind(t, true, t)",
+        },
+        TestCase{
+            .extension = "encoders",
+            .expr = "base64.encode(b'hello')",
+        },
+        TestCase{
+            .extension = "lists",
+            .expr = "[1, 2, 3].sort()",
+        },
+        TestCase{
+            .extension = "lists",
+            .expr = "['a'].sortBy(e, e)",
+        },
+        TestCase{
+            .extension = "math",
+            .expr = "math.sqrt(-1)",
+        },
+        TestCase{
+            .extension = "optional",
+            .expr = "[1, 2].first()",
+        },
+        TestCase{
+            .extension = "optional",
+            .expr = "[0][?1]",  // optional syntax auto-enabled
+        },
+        TestCase{
+            .extension = "protos",
+            .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, "
+                    "cel.expr.conformance.proto2.nested_ext)",
+        },
+        TestCase{
+            .extension = "sets",
+            .expr = "sets.contains([1], [1])",
+        },
+        TestCase{
+            .extension = "strings",
+            .expr = "'foo'.reverse()",
+        },
+        TestCase{
+            .extension = "two-var-comprehensions",
+            .expr = "[1, 2, 3, 4].all(i, v, i < v)",
+        },
+        TestCase{
+            .extension = "regex",
+            .expr = "regex.replace('abc', '$', '_end')",
+        }));
+
+}  // namespace
+}  // namespace cel
diff --git a/env/env_test.cc b/env/env_test.cc
new file mode 100644
index 000000000..dcd2d97fa
--- /dev/null
+++ b/env/env_test.cc
@@ -0,0 +1,713 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "env/env.h"
+
+// NOTE(review): the bare `#include` directives and the template names without
+// arguments in this file look like extraction damage (text between angle
+// brackets lost) - confirm against the upstream file before building.
+#include
+#include
+#include
+#include
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/status_matchers.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "checker/type_check_issue.h"
+#include "checker/type_checker_builder.h"
+#include "checker/validation_result.h"
+#include "common/ast.h"
+#include "common/ast_proto.h"
+#include "common/constant.h"
+#include "common/decl.h"
+#include "common/expr.h"
+#include "common/type.h"
+#include "common/value.h"
+#include "compiler/compiler.h"
+#include "env/config.h"
+#include "internal/proto_matchers.h"
+#include "internal/status_macros.h"
+#include "internal/testing.h"
+#include "internal/testing_descriptor_pool.h"
+#include "parser/macro.h"
+#include "parser/macro_expr_factory.h"
+#include "parser/parser_interface.h"
+#include "runtime/activation.h"
+#include "runtime/reference_resolver.h"
+#include "runtime/runtime.h"
+#include "runtime/runtime_builder.h"
+#include "runtime/runtime_options.h"
+#include "runtime/standard_runtime_builder_factory.h"
+#include "google/protobuf/arena.h"
+#include "google/protobuf/text_format.h"
+
+namespace cel {
+namespace {
+
+using ::absl_testing::IsOk;
+using ::cel::internal::test::EqualsProto;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::NotNull;
+using ::testing::Property;
+using ::testing::UnorderedElementsAre;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+// Macro expander that ignores its arguments and expands to the string
+// constant "Hello".
+Expr TestMacroExpander(MacroExprFactory& factory, absl::Span args) {
+  return factory.NewStringConst("Hello");
+}
+
+// Versioned test library named "testlib". Every version contributes
+// `testMacro1` and `testFunc1`; version 2 additionally contributes
+// `testMacro2` and `testFunc2`.
+class TestLibrary : public CompilerLibrary {
+ public:
+  explicit TestLibrary(int version)
+      : CompilerLibrary(
+            "testlib",
+            [version](ParserBuilder& builder) {
+              absl::Status status;
+              CEL_ASSIGN_OR_RETURN(
+                  auto macro1,
+                  cel::Macro::Global("testMacro1", 0, TestMacroExpander));
+              status.Update(builder.AddMacro(macro1));
+              if (version == 2) {
+                CEL_ASSIGN_OR_RETURN(
+                    auto macro2,
+                    cel::Macro::Global("testMacro2", 0, TestMacroExpander));
+                status.Update(builder.AddMacro(macro2));
+              }
+              return status;
+            },
+            [version](TypeCheckerBuilder& builder) {
+              absl::Status status;
+              CEL_ASSIGN_OR_RETURN(
+                  auto func1, cel::MakeFunctionDecl(
+                                  "testFunc1", MakeOverloadDecl(StringType())));
+              status.Update(builder.AddFunction(func1));
+              if (version == 2) {
+                CEL_ASSIGN_OR_RETURN(
+                    auto func2,
+                    cel::MakeFunctionDecl("testFunc2",
+                                          MakeOverloadDecl(StringType())));
+                status.Update(builder.AddFunction(func2));
+              }
+              return status;
+            }) {};
+};
+
+// Compiles `expr` with a compiler from `env`, then evaluates it with a
+// standard runtime built on the env's descriptor pool with the reference
+// resolver always enabled. Compilation issues are returned as
+// kInvalidArgument carrying the formatted error text; a null compiler,
+// runtime, AST or program is reported as kInternal.
+absl::StatusOr CompileAndEvalExpr(
+    Env& env, absl::string_view expr,
+    const Activation& activation = Activation()) {
+  CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler, env.NewCompiler());
+  if (compiler == nullptr) {
+    return absl::InternalError("Failed to create compiler");
+  }
+  CEL_ASSIGN_OR_RETURN(ValidationResult result, compiler->Compile(expr));
+  if (!result.GetIssues().empty()) {
+    return absl::InvalidArgumentError(result.FormatError());
+  }
+
+  cel::RuntimeOptions opts;
+  CEL_ASSIGN_OR_RETURN(
+      cel::RuntimeBuilder rt_builder,
+      cel::CreateStandardRuntimeBuilder(env.GetDescriptorPool(), opts));
+  CEL_RETURN_IF_ERROR(cel::EnableReferenceResolver(
+      rt_builder, cel::ReferenceResolverEnabled::kAlways));
+  CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime,
+                       std::move(rt_builder).Build());
+  if (runtime == nullptr) {
+    return absl::InternalError("Failed to create runtime");
+  }
+
+  CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, result.ReleaseAst());
+  if (ast == nullptr) {
+    return absl::InternalError("Failed to create AST");
+  }
+  google::protobuf::Arena arena;
+  CEL_ASSIGN_OR_RETURN(std::unique_ptr program,
+                       runtime->CreateProgram(std::move(ast)));
+  if (program == nullptr) {
+    return absl::InternalError("Failed to create program");
+  }
+  CEL_ASSIGN_OR_RETURN(Value value, program->Evaluate(&arena, activation));
+  return value;
+}
+
+// Same as CompileAndEvalExpr, but returns the result of `GetBool()` on the
+// evaluated value. NOTE(review): the value kind is not checked first, so
+// callers presumably only pass boolean expressions - confirm.
+absl::StatusOr CompileAndEvalBooleanExpr(
+    Env& env, absl::string_view expr,
+    const Activation& activation = Activation()) {
+  CEL_ASSIGN_OR_RETURN(auto value, CompileAndEvalExpr(env, expr, activation));
+  return value.GetBool();
+}
+
+// Fixture whose Env has versions 1 and 2 of "testlib" (alias "ml")
+// registered and uses the shared testing descriptor pool.
+class LibraryConfigTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    env_.RegisterCompilerLibrary("testlib", "ml", 1,
+                                 []() { return TestLibrary(1); });
+    env_.RegisterCompilerLibrary("testlib", "ml", 2,
+                                 []() { return TestLibrary(2); });
+    env_.SetDescriptorPool(internal::GetSharedTestingDescriptorPool());
+  }
+
+  Env env_;
+};
+
+// Without an explicit version, the version-2-only macro and function also
+// compile cleanly.
+TEST_F(LibraryConfigTest, DefaultVersion) {
+  Config config;
+  ASSERT_THAT(config.AddExtensionConfig("testlib"), IsOk());
+
+  env_.SetConfig(config);
+
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler());
+  ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()"));
+  ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()"));
+  ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()"));
+  ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()"));
+
+  EXPECT_THAT(result1.GetIssues(), IsEmpty());
+  EXPECT_THAT(result2.GetIssues(), IsEmpty());
+  EXPECT_THAT(result3.GetIssues(), IsEmpty());
+  EXPECT_THAT(result4.GetIssues(), IsEmpty());
+}
+
+// Pinning version 1 leaves the version-2-only macro and function undeclared.
+TEST_F(LibraryConfigTest, SpecificVersion) {
+  Config config;
+  ASSERT_THAT(config.AddExtensionConfig("testlib", 1), IsOk());
+
+  env_.SetConfig(config);
+
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env_.NewCompiler());
+  ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile("testMacro1()"));
+  ASSERT_OK_AND_ASSIGN(auto result2, compiler->Compile("testFunc1()"));
+  ASSERT_OK_AND_ASSIGN(auto result3, compiler->Compile("testMacro2()"));
+  ASSERT_OK_AND_ASSIGN(auto result4, compiler->Compile("testFunc2()"));
+
+  EXPECT_THAT(result1.GetIssues(), IsEmpty());
+  EXPECT_THAT(result2.GetIssues(), IsEmpty());
+  EXPECT_THAT(result3.GetIssues(),
+              UnorderedElementsAre(
+                  Property(&TypeCheckIssue::message,
+                           HasSubstr("undeclared reference to 'testMacro2'"))));
+  EXPECT_THAT(result4.GetIssues(),
+              UnorderedElementsAre(
+                  Property(&TypeCheckIssue::message,
+                           HasSubstr("undeclared reference to 'testFunc2'"))));
+}
+
+// A standard-library configuration together with expressions that must
+// compile without issues and expressions that must produce issues under it.
+struct StandardLibraryConfigTestCase {
+  Config::StandardLibraryConfig standard_library_config;
+  std::vector expected_valid_expressions;
+  std::vector expected_invalid_expressions;
+};
+
+class StandardLibraryConfigTest
+    : public testing::TestWithParam {};
+
+// Builds a compiler with only the given standard-library config applied and
+// checks every listed expression against its expected outcome.
+TEST_P(StandardLibraryConfigTest, StandardLibraryConfig) {
+  const StandardLibraryConfigTestCase& param = GetParam();
+  Env env;
+  env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool());
+
+  Config config;
+  ASSERT_THAT(config.SetStandardLibraryConfig(param.standard_library_config),
+              IsOk());
+  env.SetConfig(config);
+
+  ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler());
+
+  for (const std::string& expr : param.expected_valid_expressions) {
+    ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr));
+    EXPECT_THAT(result1.GetIssues(), IsEmpty())
+        << "With config: " << param.standard_library_config
+        << ", expected no issues for expr: " << expr
+        << " but got: " << result1.FormatError();
+  }
+  for (const std::string& expr : param.expected_invalid_expressions) {
+    ASSERT_OK_AND_ASSIGN(auto result1, compiler->Compile(expr));
+    EXPECT_THAT(result1.GetIssues(), Not(IsEmpty()))
+        << "With config: " << param.standard_library_config
+        << ", expected compilation error for expr: " << expr << " but got: \'"
+        << result1.FormatError() << "\'";
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    StandardLibraryConfigTest, StandardLibraryConfigTest,
+    Values(
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {},
+            .expected_valid_expressions = {"1 + 2",
+                                           "[1, 2, 3].exists(x, x == 1)",
+                                           "[1, 2, 3].all(x, x == 1)",
+                                           "[1, 2, 3].map(x, x)"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.disable = true},
+            .expected_invalid_expressions = {"1 + 2",
+                                             "[1, 2, 3].exists(x, x == 1)",
+                                             "[1, 2, 3].all(x, x == 1)",
+                                             "[1, 2, 3].map(x, x)"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.disable_macros = true},
+            .expected_valid_expressions = {"1 + 2"},
+            .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)",
+                                             "[1, 2, 3].all(x, x == 1)",
+                                             "[1, 2, 3].map(x, x)"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.excluded_macros = {"map", "all"}},
+            .expected_valid_expressions = {"[1, 2, 3].exists(x, x == 1)"},
+            .expected_invalid_expressions = {"[1, 2, 3].all(x, x == 1)",
+                                             "[1, 2, 3].map(x, x)"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.included_macros = {"map", "all"}},
+            .expected_valid_expressions = {"[1, 2, 3].all(x, x == 1)",
+                                           "[1, 2, 3].map(x, x)"},
+            .expected_invalid_expressions = {"[1, 2, 3].exists(x, x == 1)"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.excluded_functions = {{"_+_", ""}}},
+            .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]",
+                                             "'hello' + 'world'"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config =
+                {.excluded_functions = {{"_+_", "add_bytes"},
+                                        {"_+_", "add_list"},
+                                        {"_+_", "add_string"}}},
+            .expected_valid_expressions = {"1 + 2"},
+            .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]",
+                                             "'hello' + 'world'"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config = {.included_functions = {{"_+_", ""}}},
+            .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]",
+                                           "'hello' + 'world'"},
+        },
+        StandardLibraryConfigTestCase{
+            .standard_library_config =
+                {.included_functions = {{"_+_", "add_int64"},
+                                        {"_+_", "add_list"}}},
+            .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"},
+            .expected_invalid_expressions = {"'hello' + 'world'"},
+        }));
+
+TEST(ContainerConfigTest, ContainerConfig) {
+  Env env;
+  env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool());
+
Config config; + config.SetContainerConfig({.name = "cel.expr.conformance.proto2"}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +struct TypeInfoTestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TypeInfoTestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + Config::VariableConfig variable_config; + variable_config.name = "test"; + variable_config.type_info = param.type_info; + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_THAT(compiler, NotNull()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("test")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) + << " error: " << result.FormatError(); + + // Obtain the inferred return type of the expression `test`. 
+ const Ast* ast = result.GetAst(); + ASSERT_THAT(ast, NotNull()); + cel::expr::CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*ast, &checked_expr), IsOk()); + auto it = checked_expr.type_map().find(checked_expr.expr().id()); + ASSERT_NE(it, checked_expr.type_map().end()); + + cel::expr::Type actual_type_pb = it->second; + EXPECT_THAT(actual_type_pb, EqualsProto(expected_type_pb)); +} + +std::vector GetTypeInfoTestCases() { + return { + TypeInfoTestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TypeInfoTestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TypeInfoTestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TypeInfoTestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + // TypeParam is replaced with dyn by the type checker. 
+ .expected_type_pb = + "abstract_type { name: 'A' parameter_types { dyn {} } }", + }, + TypeInfoTestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TypeInfoTestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TypeInfoTestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TypeInfoTestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TypeInfoTestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + // TypeParam is replaced with dyn by the type checker. + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { dyn {} } " + "parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, TypeInfoTest, + ValuesIn(GetTypeInfoTestCases())); + +struct VariableConfigWithValueTestCase { + Config::VariableConfig variable_config; + std::string validate_type_expr; + std::string validate_value_expr; +}; + +class VariableConfigWithValueTest + : public testing::TestWithParam {}; + +TEST_P(VariableConfigWithValueTest, VariableConfigWithValue) { + const VariableConfigWithValueTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + ASSERT_THAT(config.AddVariableConfig(param.variable_config), IsOk()); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN( + bool type_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_type_expr)); + ASSERT_TRUE(type_as_expected) << " expr: " << param.validate_type_expr; + if (!param.validate_value_expr.empty()) { + ASSERT_OK_AND_ASSIGN( + bool value_as_expected, + CompileAndEvalBooleanExpr(env, param.validate_value_expr)); + ASSERT_TRUE(value_as_expected) << " expr: " << 
param.validate_value_expr; + } +} + +Config::VariableConfig MakeConstant( + absl::string_view variable_name, absl::string_view type_name, + absl::AnyInvocable setter) { + Config::VariableConfig variable_config; + variable_config.name = variable_name; + Constant c; + setter(c); + variable_config.type_info.name = type_name; + variable_config.value = c; + return variable_config; +} + +std::vector +GetVariableConfigWithValueTestCases() { + return { + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "null", [](auto& c) { c.set_null_value(nullptr); }), + .validate_type_expr = "type(x) == type(null)", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "bool", [](auto& c) { c.set_bool_value(true); }), + .validate_type_expr = "type(x) == bool", + .validate_value_expr = "x == true", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "int", [](Constant& c) { c.set_int_value(42); }), + .validate_type_expr = "type(x) == int", + .validate_value_expr = "x == 42", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "uint", [](Constant& c) { c.set_uint_value(777); }), + .validate_type_expr = "type(x) == uint", + .validate_value_expr = "x == 777u", + }, + VariableConfigWithValueTestCase{ + .variable_config = + MakeConstant("x", "double", + [](Constant& c) { c.set_double_value(1.0 / 3.0); }), + .validate_type_expr = "type(x) == double", + .validate_value_expr = "x > 0.333 && x < 0.334", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant("x", "bytes", + [](Constant& c) { + c.set_bytes_value(absl::string_view( + "\xff\x00\x01", 3)); + }), + .validate_type_expr = "type(x) == bytes", + .validate_value_expr = "x == b'\\xff\\x00\\x01'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "string", [](Constant& c) { c.set_string_value("hello"); }), + .validate_type_expr = "type(x) == string", + .validate_value_expr = "x == 
'hello'", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "timestamp", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_timestamp_value(absl::FromUnixSeconds(1767323045)); + }), + .validate_type_expr = + "type(x) == type(timestamp('2026-01-02T03:04:05Z'))", + .validate_value_expr = "x == timestamp('2026-01-02T03:04:05Z')", + }, + VariableConfigWithValueTestCase{ + .variable_config = MakeConstant( + "x", "duration", + [](Constant& c) { + // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) + c.set_duration_value(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3)); + }), + .validate_type_expr = "type(x) == type(duration('1h2m3s'))", + .validate_value_expr = "x == duration('1h2m3s')", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(VariableConfigTest, VariableConfigWithValueTest, + ValuesIn(GetVariableConfigWithValueTestCases())); + +struct FunctionConfigTestCase { + Config::FunctionConfig function_config; + std::vector variable_configs; + std::string expr; + std::string expected_error; +}; + +class FunctionConfigTest + : public testing::TestWithParam {}; + +TEST_P(FunctionConfigTest, FunctionConfig) { + const FunctionConfigTestCase& param = GetParam(); + + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + for (const Config::VariableConfig& variable_config : param.variable_configs) { + ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); + } + ASSERT_THAT(config.AddFunctionConfig(param.function_config), IsOk()); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + if (param.expected_error.empty()) { + EXPECT_TRUE(result.GetIssues().empty()) + << " expr: " << param.expr << " error: " << result.FormatError(); + } else { + EXPECT_THAT(result.GetIssues(), + UnorderedElementsAre(Property(&TypeCheckIssue::message, + 
HasSubstr(param.expected_error)))) + << " expr: " << param.expr << " error: " << result.FormatError(); + } +} + +std::vector GetFunctionConfigTestCases() { + return {{ + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(int,int)", + .examples = {"add(1, 2) -> 3"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "add(1, 2)", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 3"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "int"}, + }, + }, + }, + .expr = "1.add(2) == 3", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "plus(string,string)", + .examples = + {"add('hello', 'world') -> 'hello world'"}, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "add('hello', 'world')", + .expected_error = "found no matching overload for 'add' applied to " + "'(string, string)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "add", + .overload_configs = + { + { + .overload_id = "int.plus(int)", + .examples = {"1.add(2) -> 'three'"}, + .is_member_function = true, + .parameters = {{.name = "int"}, {.name = "int"}}, + .return_type = {.name = "string"}, + }, + }, + }, + .expr = "1.add(2) == 3", + .expected_error = "found no matching overload for '_==_' applied to " + "'(string, int)'", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "double"}}}}, + .return_type = {.name 
= "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Matching opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + }, + FunctionConfigTestCase{ + .function_config = + { + .name = "sum", + .description = "Sum a collection, which is an opaque type.", + .overload_configs = + { + { + .overload_id = "sum(collection)", + .examples = {"sum(my_collection) -> 100"}, + .parameters = {{.name = "collection", + .params = {{.name = "int"}}}}, + .return_type = {.name = "double"}, + }, + }, + }, + .variable_configs = + { + {.name = "my_collection", + .description = "Mismatched opaque type.", + .type_info = {.name = "collection", + .params = {{.name = "double"}}}}, + }, + .expr = "sum(my_collection) / 3.0", + .expected_error = "found no matching overload for 'sum' applied to " + "'(collection(double))'", + }, + }}; +} + +INSTANTIATE_TEST_SUITE_P(FunctionConfigTest, FunctionConfigTest, + ::testing::ValuesIn(GetFunctionConfigTestCases())); + +} // namespace +} // namespace cel diff --git a/env/env_yaml.cc b/env/env_yaml.cc new file mode 100644 index 000000000..0035709e9 --- /dev/null +++ b/env/env_yaml.cc @@ -0,0 +1,1026 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/env_yaml.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "internal/strings.h" +#include "yaml-cpp/emitter.h" +#include "yaml-cpp/emittermanip.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/mark.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +namespace { + +std::string FormatYamlErrorMessage(absl::string_view yaml, + absl::string_view error, + const YAML::Mark& mark) { + if (mark.is_null()) { + return std::string(error); + } + std::string message; + absl::StrAppend(&message, mark.line + 1, ":", mark.column + 1, ": ", error, + "\n|"); + size_t start = mark.pos - mark.column; + size_t end = yaml.find('\n', mark.pos); + if (end == std::string::npos) { + end = yaml.size(); + } + + absl::StrAppend(&message, yaml.substr(start, end - start), "\n|", + std::string(mark.column, ' '), "^"); + + return message; +} + +absl::StatusOr LoadYaml(const std::string& yaml) { + try { + return YAML::Load(yaml); + } catch (YAML::ParserException& e) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, e.msg, e.mark)); + } +} + +absl::Status YamlError(absl::string_view yaml, const YAML::Node& node, + absl::string_view error) { + return absl::InvalidArgumentError( + FormatYamlErrorMessage(yaml, error, node.Mark())); +} + +std::string 
GetString(absl::string_view yaml, const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return ""; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return ""; + } +} + +bool IsBinary(const YAML::Node& node) { + return node.Tag() == "!!binary" || node.Tag() == "tag:yaml.org,2002:binary"; +} + +absl::StatusOr GetBinary(absl::string_view yaml, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar() || !IsBinary(node)) { + return ""; + } + std::string binary; + // Instead of using the YAML::Binary type, we use absl::Base64Unescape + // because YAML::Binary is lenient to Base64 decoding errors. + if (absl::Base64Unescape(GetString(yaml, node), &binary)) { + return binary; + } else { + return YamlError(yaml, node, + "Node '" + GetString(yaml, node) + + "' is not a valid Base64 encoded binary"); + } +} + +absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, + const YAML::Node& node) { + if (!node.IsDefined() || !node.IsScalar()) { + return false; + } + try { + return node.as(); + } catch (YAML::Exception& e) { + return YamlError(yaml, node, + "Node '" + std::string(key) + "' is not a boolean"); + } +} + +absl::Status ParseName(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node name = root["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' is not a string"); + } + config.SetName(GetString(yaml, name)); + } + return absl::OkStatus(); +} + +absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node container = root["container"]; + if (container.IsDefined()) { + if (!container.IsScalar()) { + return YamlError(yaml, container, "Node 'container' is not a string"); + } + config.SetContainerConfig({.name = GetString(yaml, 
container)}); + } + return absl::OkStatus(); +} + +absl::Status ParseExtensionConfigs(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node extensions = root["extensions"]; + if (!extensions.IsDefined()) { + return absl::OkStatus(); + } + if (!extensions.IsSequence()) { + return YamlError(yaml, extensions, "Node 'extensions' is not a sequence"); + } + + for (const YAML::Node& extension : extensions) { + if (!extension || !extension.IsMap()) { + return YamlError(yaml, extension, "Extension is not a map"); + } + const YAML::Node name = extension["name"]; + if (!name || !name.IsScalar()) { + return YamlError(yaml, name, "Extension name is not a string"); + } + std::string name_str = GetString(yaml, name); + + const YAML::Node version = extension["version"]; + std::string version_str = GetString(yaml, version); + int extension_version; + if (version.IsDefined()) { + bool is_valid_version = false; + if (version.IsScalar()) { + if (version_str == "latest") { + extension_version = Config::ExtensionConfig::kLatest; + is_valid_version = true; + } else { + if (absl::SimpleAtoi(version_str, &extension_version) && + extension_version >= 0) { + is_valid_version = true; + } + } + } + if (!is_valid_version) { + return YamlError( + yaml, version, + absl::StrCat("Extension '", name_str, + "' version is not a valid number or 'latest'")); + } + } else { + extension_version = Config::ExtensionConfig::kLatest; + } + absl::Status add_status = + config.AddExtensionConfig(name_str, extension_version); + if (!add_status.ok()) { + return YamlError(yaml, extension, add_status.message()); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> ParseMacroList( + absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set macro_set; + const YAML::Node macros = standard_library[std::string(key)]; + if (!macros.IsDefined()) { + return macro_set; + } + if (!macros.IsSequence()) { + return YamlError(yaml, macros, + 
absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& macro : macros) { + if (!macro.IsScalar()) { + return YamlError(yaml, macro, + absl::StrCat("Entry in '", key, "' is not a string")); + } + macro_set.insert(GetString(yaml, macro)); + } + return macro_set; +} + +absl::StatusOr>> +ParseFunctionList(absl::string_view yaml, const YAML::Node& standard_library, + absl::string_view key) { + absl::flat_hash_set> function_set; + const YAML::Node functions = standard_library[std::string(key)]; + if (!functions.IsDefined()) { + return function_set; + } + if (!functions.IsSequence()) { + return YamlError(yaml, functions, + absl::StrCat("Node '", key, "' is not a sequence")); + } + for (const YAML::Node& function : functions) { + if (!function.IsMap()) { + return YamlError(yaml, function, + absl::StrCat("Entry in '", key, "' is not a map")); + } + const YAML::Node name = function["name"]; + if (!name.IsDefined()) { + return YamlError( + yaml, function, + absl::StrCat("Function name in not specified in '", key, "'")); + } + if (!name.IsScalar()) { + return YamlError( + yaml, name, + absl::StrCat("Function name in '", key, "' entry is not a string")); + } + std::string name_str = GetString(yaml, name); + const YAML::Node overloads = function["overloads"]; + if (!overloads.IsDefined()) { + function_set.insert(std::make_pair(name_str, "")); + } else { + if (!overloads.IsSequence()) { + return YamlError( + yaml, overloads, + absl::StrCat("Overloads in '", key, "' entry is not a sequence")); + } + for (const YAML::Node& overload : overloads) { + if (!overload.IsMap()) { + return YamlError( + yaml, overload, + absl::StrCat("Overload in '", key, "' entry is not a map")); + } + const YAML::Node id = overload["id"]; + if (!id || !id.IsScalar()) { + return YamlError( + yaml, id, + absl::StrCat("Overload id in '", key, "' entry is not a string")); + } + function_set.insert(std::make_pair(name_str, GetString(yaml, id))); + } + } + } + return function_set; +} 
+ +absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node standard_library = root["stdlib"]; + if (!standard_library.IsDefined()) { + return absl::OkStatus(); + } + + if (!standard_library.IsMap()) { + return YamlError(yaml, standard_library, + "Standard library config ('stdlib') is not a map"); + } + + Config::StandardLibraryConfig standard_library_config; + + const YAML::Node disable = standard_library["disable"]; + if (disable.IsDefined()) { + if (!disable.IsScalar()) { + return YamlError(yaml, disable, "Node 'disable' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable, + GetBool(yaml, "disable", disable)); + } + + const YAML::Node disable_macros = standard_library["disable_macros"]; + if (disable_macros.IsDefined()) { + if (!disable_macros.IsScalar()) { + return YamlError(yaml, disable_macros, + "Node 'disable_macros' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(standard_library_config.disable_macros, + GetBool(yaml, "disable_macros", disable_macros)); + } + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_macros, + ParseMacroList(yaml, standard_library, "include_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_macros, + ParseMacroList(yaml, standard_library, "exclude_macros")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.included_functions, + ParseFunctionList(yaml, standard_library, "include_functions")); + + CEL_ASSIGN_OR_RETURN( + standard_library_config.excluded_functions, + ParseFunctionList(yaml, standard_library, "exclude_functions")); + + return config.SetStandardLibraryConfig(standard_library_config); +} + +absl::StatusOr ParseTypeInfo(const YAML::Node& node, + absl::string_view yaml) { + Config::TypeInfo type_config; + const YAML::Node type_name = node["type_name"]; + if (!type_name.IsDefined()) { + return type_config; + } + if (!type_name || !type_name.IsScalar()) { + return YamlError(yaml, type_name, "Node 
'type_name' is not a string"); + } + type_config.name = GetString(yaml, type_name); + + const YAML::Node is_type_param = node["is_type_param"]; + if (is_type_param.IsDefined()) { + if (!is_type_param.IsScalar()) { + return YamlError(yaml, is_type_param, + "Node 'is_type_param' is not a boolean"); + } + CEL_ASSIGN_OR_RETURN(type_config.is_type_param, + GetBool(yaml, "is_type_param", is_type_param)); + } + + const YAML::Node params = node["params"]; + if (!params.IsDefined()) { + return type_config; + } + if (!params.IsSequence()) { + return YamlError(yaml, params, "Node 'params' is not a sequence"); + } + for (const YAML::Node& param : params) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_config, + ParseTypeInfo(param, yaml)); + type_config.params.push_back(param_config); + } + + return type_config; +} + +bool CompareTypeInfo(const Config::TypeInfo& a, const Config::TypeInfo& b) { + if (a.name != b.name) { + return a.name < b.name; + } + if (a.params.size() != b.params.size()) { + return a.params.size() < b.params.size(); + } + for (size_t i = 0; i < a.params.size(); ++i) { + if (CompareTypeInfo(a.params[i], b.params[i])) { + return true; + } + if (CompareTypeInfo(b.params[i], a.params[i])) { + return false; + } + } + return false; // They are equal +} + +ConstantKindCase GetConstantKindCase(absl::string_view type_name) { + static const auto kTypeNameToConstantKindCase = + absl::NoDestructor>({ + {"null", ConstantKindCase::kNull}, + {"bool", ConstantKindCase::kBool}, + {"int", ConstantKindCase::kInt}, + {"uint", ConstantKindCase::kUint}, + {"double", ConstantKindCase::kDouble}, + {"string", ConstantKindCase::kString}, + {"bytes", ConstantKindCase::kBytes}, + {"duration", ConstantKindCase::kDuration}, + {"timestamp", ConstantKindCase::kTimestamp}, + }); + if (auto it = kTypeNameToConstantKindCase->find(type_name); + it != kTypeNameToConstantKindCase->end()) { + return it->second; + } + return ConstantKindCase::kUnspecified; +} + +absl::StatusOr 
ParseConstantValue(absl::string_view yaml,
+                                            const YAML::Node& node,
+                                            ConstantKindCase constant_kind_case,
+                                            absl::string_view value) {
+  switch (constant_kind_case) {
+    case ConstantKindCase::kNull:
+      if (!value.empty()) {
+        return YamlError(yaml, node, "Failed to parse null constant");
+      }
+      return Constant(nullptr);
+    case ConstantKindCase::kBool:
+      if (absl::EqualsIgnoreCase(value, "true")) {
+        return Constant(true);
+      } else if (absl::EqualsIgnoreCase(value, "false")) {
+        return Constant(false);
+      } else {
+        return YamlError(yaml, node, "Failed to parse bool constant");
+      }
+    case ConstantKindCase::kInt:
+      int64_t int_value;
+      if (!absl::SimpleAtoi(value, &int_value)) {
+        return YamlError(yaml, node, "Failed to parse int constant");
+      }
+      return Constant(int_value);
+    case ConstantKindCase::kUint:
+      uint64_t uint_value;
+      if (absl::EndsWith(value, "u")) {
+        value = value.substr(0, value.size() - 1);
+      }
+      if (!absl::SimpleAtoi(value, &uint_value)) {
+        return YamlError(yaml, node, "Failed to parse uint constant");
+      }
+      return Constant(uint_value);
+    case ConstantKindCase::kDouble:
+      double double_value;
+      if (!absl::SimpleAtod(value, &double_value)) {
+        return YamlError(yaml, node, "Failed to parse double constant");
+      }
+      return Constant(double_value);
+    case ConstantKindCase::kBytes: {
+      if (!IsBinary(node)) {
+        absl::StatusOr<std::string> bytes_literal =
+            internal::ParseBytesLiteral(value);
+        if (bytes_literal.ok()) {
+          return Constant(BytesConstant(*bytes_literal));
+        }
+      }
+      return Constant(BytesConstant(value));
+    }
+    case ConstantKindCase::kString:
+      return Constant(StringConstant(value));
+    case ConstantKindCase::kDuration: {
+      // Duration is deprecated as a builtin type, but still supported for
+      // compatibility.
+      absl::Duration duration_value;
+      if (!absl::ParseDuration(value, &duration_value)) {
+        return YamlError(yaml, node, "Failed to parse duration constant");
+      }
+      return Constant(duration_value);
+    }
+    case ConstantKindCase::kTimestamp: {
+      // Timestamp is deprecated as a builtin type, but still supported for
+      // compatibility.
+      absl::Time timestamp_value;
+      std::string error;
+      // Format: YYYY-MM-DDThh:mm:ssZ
+      if (!absl::ParseTime("%Y-%m-%d%ET%H:%M:%E*SZ", value, &timestamp_value,
+                           &error)) {
+        return YamlError(
+            yaml, node,
+            absl::StrCat("Failed to parse timestamp constant: ", error,
+                         " supported format: YYYY-MM-DDThh:mm:ssZ"));
+      }
+      return Constant(timestamp_value);
+    }
+    default:
+      // This should never happen.
+      return YamlError(yaml, node, "Constant type is not supported");
+  }
+}
+
+// Parses the optional top-level `variables` sequence of `root` and registers
+// one Config::VariableConfig per entry through Config::AddVariableConfig.
+//
+// Every entry must be a map with a scalar `name`. `description` (scalar) and
+// `value` (scalar) are optional; the type is read from the entry itself by
+// ParseTypeInfo. A `value` is rejected when GetConstantKindCase() reports
+// kUnspecified for the type name.
+//
+// Returns OkStatus when `variables` is absent, a YamlError for nodes of the
+// wrong shape, and otherwise propagates failures from ParseTypeInfo,
+// GetBinary, ParseConstantValue and AddVariableConfig.
+absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml,
+                                  const YAML::Node& root) {
+  const YAML::Node variables = root["variables"];
+  if (!variables.IsDefined()) {
+    return absl::OkStatus();
+  }
+  if (!variables.IsSequence()) {
+    return YamlError(yaml, variables, "Node 'variables' is not a sequence");
+  }
+
+  for (const YAML::Node& variable : variables) {
+    Config::VariableConfig variable_config;
+    if (!variable || !variable.IsMap()) {
+      return YamlError(yaml, variable, "Variable is not a map");
+    }
+    const YAML::Node name = variable["name"];
+    if (!name || !name.IsScalar()) {
+      return YamlError(yaml, name, "Variable name is not a string");
+    }
+    variable_config.name = GetString(yaml, name);
+    const YAML::Node description = variable["description"];
+    if (description.IsDefined()) {
+      if (!description.IsScalar()) {
+        return YamlError(yaml, description,
+                         "Variable description is not a string");
+      }
+      variable_config.description = GetString(yaml, description);
+    }
+
+    CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml));
+    ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name);
+    std::string value_str;
+    YAML::Node value = variable["value"];
+    if (value.IsDefined()) {
+      if (constant_kind_case == ConstantKindCase::kUnspecified) {
+        return YamlError(yaml, value,
+                         absl::StrCat("Constant type '", type_info.name,
+                                      "' is not supported"));
+      }
+      if (!value.IsScalar()) {
+        return YamlError(yaml, value, "Variable value is not a scalar");
+      }
+      // Nodes for which IsBinary() holds are decoded by GetBinary(); all other
+      // scalars are taken as text.
+      if (IsBinary(value)) {
+        CEL_ASSIGN_OR_RETURN(value_str, GetBinary(yaml, value));
+      } else {
+        value_str = GetString(yaml, value);
+      }
+    }
+
+    variable_config.type_info = type_info;
+
+    // NOTE(review): this runs whenever the type name maps to a constant kind,
+    // even when no `value` key was given. `value_str` is then empty, so an
+    // int-typed variable without a value would presumably be rejected by
+    // ParseConstantValue and a string-typed one would become the constant "".
+    // Confirm that this is intended.
+    if (constant_kind_case != ConstantKindCase::kUnspecified) {
+      CEL_ASSIGN_OR_RETURN(
+          variable_config.value,
+          ParseConstantValue(yaml, value, constant_kind_case, value_str));
+    }
+
+    CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config));
+  }
+  return absl::OkStatus();
+}
+
+// Parses one entry of a function's `overloads` sequence into a
+// Config::FunctionOverloadConfig.
+//
+// All keys are optional: `id` (scalar), `examples` (sequence of scalars),
+// `target` (map; marks the overload as a member function and becomes the
+// first parameter), `args` (sequence of maps, appended as parameters) and
+// `return` (map). Returns a YamlError when a node has the wrong shape and
+// propagates ParseTypeInfo failures.
+absl::StatusOr<Config::FunctionOverloadConfig> ParseFunctionOverloadConfig(
+    absl::string_view yaml, const YAML::Node& overload) {
+  Config::FunctionOverloadConfig overload_config;
+  if (!overload || !overload.IsMap()) {
+    return YamlError(yaml, overload, "Function overload is not a map");
+  }
+  const YAML::Node id = overload["id"];
+  if (id.IsDefined()) {
+    if (!id.IsScalar()) {
+      return YamlError(yaml, id, "Function overload id is not a string");
+    }
+    overload_config.overload_id = GetString(yaml, id);
+  }
+  const YAML::Node examples = overload["examples"];
+  if (examples.IsDefined()) {
+    if (!examples.IsSequence()) {
+      return YamlError(yaml, examples,
+                       "Function overload examples is not a sequence");
+    }
+    for (const YAML::Node& example : examples) {
+      if (!example.IsScalar()) {
+        return YamlError(yaml, example,
+                         "Function overload example is not a string");
+      }
+      overload_config.examples.push_back(GetString(yaml, example));
+    }
+  }
+
+  const YAML::Node target = overload["target"];
+  if (target.IsDefined()) {
+    if (!target.IsMap()) {
+      return YamlError(yaml, target, "Function overload target is not a map");
+    }
+    CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info,
+                         ParseTypeInfo(target, yaml));
+
overload_config.is_member_function = true;
+    overload_config.parameters.push_back(type_info);
+  }
+
+  const YAML::Node args = overload["args"];
+  if (args.IsDefined()) {
+    if (!args.IsSequence()) {
+      return YamlError(yaml, args, "Function overload args is not a sequence");
+    }
+    for (const YAML::Node& arg : args) {
+      if (!arg.IsMap()) {
+        return YamlError(yaml, arg, "Function overload arg is not a map");
+      }
+      CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info,
+                           ParseTypeInfo(arg, yaml));
+      overload_config.parameters.push_back(type_info);
+    }
+  }
+
+  const YAML::Node return_type = overload["return"];
+  if (return_type.IsDefined()) {
+    if (!return_type.IsMap()) {
+      return YamlError(yaml, return_type,
+                       "Function overload return type is not a map");
+    }
+    CEL_ASSIGN_OR_RETURN(overload_config.return_type,
+                         ParseTypeInfo(return_type, yaml));
+  }
+  return overload_config;
+}
+
+// Parses the optional top-level `functions` sequence of `root` and registers
+// one Config::FunctionConfig per entry through Config::AddFunctionConfig.
+//
+// Every entry must be a map with a scalar `name`; `description` (scalar) and
+// `overloads` (sequence, each item parsed by ParseFunctionOverloadConfig) are
+// optional. Returns OkStatus when `functions` is absent, a YamlError for
+// nodes of the wrong shape, and otherwise propagates errors from the overload
+// parser and from AddFunctionConfig.
+absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml,
+                                  const YAML::Node& root) {
+  const YAML::Node functions = root["functions"];
+  if (!functions.IsDefined()) {
+    return absl::OkStatus();
+  }
+  if (!functions.IsSequence()) {
+    return YamlError(yaml, functions, "Node 'functions' is not a sequence");
+  }
+
+  for (const YAML::Node& function : functions) {
+    Config::FunctionConfig function_config;
+    if (!function || !function.IsMap()) {
+      return YamlError(yaml, function, "Function is not a map");
+    }
+    const YAML::Node name = function["name"];
+    if (!name || !name.IsScalar()) {
+      return YamlError(yaml, name, "Function name is not a string");
+    }
+    function_config.name = GetString(yaml, name);
+    const YAML::Node description = function["description"];
+    if (description.IsDefined()) {
+      if (!description.IsScalar()) {
+        return YamlError(yaml, description,
+                         "Function description is not a string");
+      }
+      function_config.description = GetString(yaml, description);
+    }
+    const YAML::Node overloads = function["overloads"];
+    if (overloads.IsDefined()) {
+      if (!overloads.IsSequence()) {
+        return YamlError(yaml, overloads,
+                         "Function 'overloads' item is not a sequence");
+      }
+
+      for (const YAML::Node& overload : overloads) {
+        CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config,
+                             ParseFunctionOverloadConfig(yaml, overload));
+        function_config.overload_configs.push_back(std::move(overload_config));
+      }
+    }
+
+    CEL_RETURN_IF_ERROR(config.AddFunctionConfig(function_config));
+  }
+  return absl::OkStatus();
+}
+
+// Emits the `container` key with the double-quoted container name. Emits
+// nothing when the container config reports IsEmpty().
+void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) {
+  const auto& container_config = env_config.GetContainerConfig();
+  if (container_config.IsEmpty()) {
+    return;
+  }
+
+  out << YAML::Key << "container";
+  out << YAML::Value << YAML::DoubleQuoted << container_config.name;
+}
+
+// Emits the `extensions` sequence, one map per extension, ordered by name.
+// `version` is written only when it differs from ExtensionConfig::kLatest.
+// Emits nothing when there are no extensions.
+void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) {
+  if (env_config.GetExtensionConfigs().empty()) {
+    return;
+  }
+
+  // Sort the extensions to make the output deterministic.
+  std::vector<Config::ExtensionConfig> sorted_extensions =
+      env_config.GetExtensionConfigs();
+  absl::c_sort(sorted_extensions, [](const Config::ExtensionConfig& a,
+                                     const Config::ExtensionConfig& b) {
+    return a.name < b.name;
+  });
+  out << YAML::Key << "extensions";
+  out << YAML::Value << YAML::BeginSeq;
+  for (const Config::ExtensionConfig& extension_config : sorted_extensions) {
+    out << YAML::BeginMap;
+    out << YAML::Key << "name";
+    out << YAML::Value << YAML::DoubleQuoted << extension_config.name;
+    if (extension_config.version != Config::ExtensionConfig::kLatest) {
+      out << YAML::Key << "version";
+      out << YAML::Value << extension_config.version;
+    }
+    out << YAML::EndMap;
+  }
+  out << YAML::EndSeq;
+}
+
+// Emits `key` followed by the macro names as a sequence of double-quoted
+// scalars. The names are sorted first because the hash set has no stable
+// iteration order. Emits nothing when `macros` is empty.
+void EmitMacroList(YAML::Emitter& out, absl::string_view key,
+                   const absl::flat_hash_set<std::string>& macros) {
+  if (macros.empty()) {
+    return;
+  }
+  out << YAML::Key << std::string(key);
+  out << YAML::Value << YAML::BeginSeq;
+  std::vector<std::string> sorted_macros(macros.begin(), macros.end());
+  absl::c_sort(sorted_macros);
+  for (const std::string& macro :
sorted_macros) {
+    out << YAML::Value << YAML::DoubleQuoted << macro;
+  }
+  out << YAML::EndSeq;
+}
+
+// Emits `key` followed by a sequence with one map per function name found in
+// `functions`, a set of (function name, overload id) pairs.
+//
+// Each map carries the double-quoted `name` and, unless the function's only
+// overload id is the empty string, an `overloads` sequence of `id` maps.
+// Function names and overload ids are both written in sorted order. Emits
+// nothing when `functions` is empty.
+void EmitFunctionList(
+    YAML::Emitter& out, absl::string_view key,
+    const absl::flat_hash_set<std::pair<std::string, std::string>>& functions) {
+  if (functions.empty()) {
+    return;
+  }
+
+  // Build a map from function name to a vector of overload ids.
+  // Using std::map ensures function names are sorted.
+  std::map<std::string, std::vector<std::string>> function_overloads;
+  for (const auto& pair : functions) {
+    function_overloads[pair.first].push_back(pair.second);
+  }
+
+  out << YAML::Key << std::string(key) << YAML::Value << YAML::BeginSeq;
+  for (auto const& [name, overloads] : function_overloads) {
+    out << YAML::BeginMap;
+    out << YAML::Key << "name";
+    out << YAML::Value << YAML::DoubleQuoted << name;
+
+    // If the only overload is the empty string, it signifies that all overloads
+    // of the function are included/excluded. In this case, we don't emit the
+    // "overloads" key. Otherwise, emit the specific overloads.
+    if (!(overloads.size() == 1 && overloads[0].empty())) {
+      // Sort overloads for deterministic output.
+      std::vector<std::string> sorted_overloads = overloads;
+      absl::c_sort(sorted_overloads);
+
+      out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq;
+      for (const std::string& overload : sorted_overloads) {
+        out << YAML::BeginMap;
+        out << YAML::Key << "id";
+        out << YAML::Value << YAML::DoubleQuoted << overload;
+        out << YAML::EndMap;
+      }
+      out << YAML::EndSeq;
+    }
+    out << YAML::EndMap;
+  }
+  out << YAML::EndSeq;
+}
+
+// Emits the `stdlib` map: `disable` and `disable_macros` (only when true),
+// then the include/exclude macro lists and the include/exclude function
+// lists. Emits nothing when the standard library config reports IsEmpty().
+void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) {
+  const Config::StandardLibraryConfig& standard_library_config =
+      env_config.GetStandardLibraryConfig();
+  if (standard_library_config.IsEmpty()) {
+    return;
+  }
+
+  out << YAML::Key << "stdlib" << YAML::Value << YAML::BeginMap;
+  if (standard_library_config.disable) {
+    out << YAML::Key << "disable" << YAML::Value << true;
+  }
+  if (standard_library_config.disable_macros) {
+    out << YAML::Key << "disable_macros" << YAML::Value << true;
+  }
+  EmitMacroList(out, "include_macros", standard_library_config.included_macros);
+  EmitMacroList(out, "exclude_macros", standard_library_config.excluded_macros);
+  EmitFunctionList(out, "include_functions",
+                   standard_library_config.included_functions);
+  EmitFunctionList(out, "exclude_functions",
+                   standard_library_config.excluded_functions);
+  out << YAML::EndMap;
+}
+
+// Emits the keys describing `type_info` into the map that is currently open
+// on `out`: `type_name`, `is_type_param` (only when true) and, for
+// parameterized types, a `params` sequence built by recursing on each
+// parameter.
+void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) {
+  // Note: the map is already started when this is called, so we don't emit
+  // BeginMap here or EndMap at the end.
+  out << YAML::Key << "type_name";
+  out << YAML::Value << YAML::DoubleQuoted << type_info.name;
+  if (type_info.is_type_param) {
+    out << YAML::Key << "is_type_param" << YAML::Value << true;
+  }
+  if (!type_info.params.empty()) {
+    out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq;
+    for (const Config::TypeInfo& param : type_info.params) {
+      out << YAML::BeginMap;
+      EmitTypeInfo(param, out);
+      out << YAML::EndMap;
+    }
+    out << YAML::EndSeq;
+  }
+}
+
+// Emits the `variables` sequence, one map per variable, ordered by name.
+//
+// Each map carries `name`, an optional `description`, the type keys written
+// by EmitTypeInfo and, when the variable has a constant value, a `value` key
+// whose formatting depends on the constant kind. Null and unspecified
+// constants produce no `value` key. Emits nothing when there are no
+// variables.
+void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) {
+  const auto& variable_configs = env_config.GetVariableConfigs();
+  if (variable_configs.empty()) {
+    return;
+  }
+
+  // Sort variable_configs by name to ensure deterministic output.
+  std::vector<Config::VariableConfig> sorted_variable_configs =
+      variable_configs;
+  absl::c_sort(sorted_variable_configs,
+               [](const Config::VariableConfig& a,
+                  const Config::VariableConfig& b) { return a.name < b.name; });
+
+  out << YAML::Key << "variables";
+  out << YAML::Value << YAML::BeginSeq;
+  for (const Config::VariableConfig& variable_config :
+       sorted_variable_configs) {
+    out << YAML::BeginMap;
+    out << YAML::Key << "name";
+    out << YAML::Value << YAML::DoubleQuoted << variable_config.name;
+    if (!variable_config.description.empty()) {
+      out << YAML::Key << "description";
+      out << YAML::Value << YAML::DoubleQuoted << variable_config.description;
+    }
+    EmitTypeInfo(variable_config.type_info, out);
+    if (variable_config.value.has_value()) {
+      const Constant& constant = variable_config.value;
+      switch (constant.kind_case()) {
+        case ConstantKindCase::kUnspecified:
+        case ConstantKindCase::kNull:
+          break;
+        case ConstantKindCase::kBool:
+          out << YAML::Key << "value" << YAML::Value << constant.bool_value();
+          break;
+        case ConstantKindCase::kInt:
+          out << YAML::Key << "value" << YAML::Value << constant.int_value();
+          break;
+        case ConstantKindCase::kUint:
+          out << YAML::Key << "value" << YAML::Value << constant.uint_value();
+          break;
+        case ConstantKindCase::kDouble:
+          out << YAML::Key << "value" << YAML::Value << constant.double_value();
+          break;
+        case ConstantKindCase::kBytes: {
+          // Bytes are written as a b"..." literal with every byte as a \xNN
+          // escape.
+          out << YAML::Key << "value";
+          const std::string& bytes_value = constant.bytes_value();
+          std::string hex_escaped = "b\"";
+          for (unsigned char byte : bytes_value) {
+            absl::StrAppend(&hex_escaped, "\\x");
+            absl::StrAppendFormat(&hex_escaped, "%02x", byte);
+          }
+          absl::StrAppend(&hex_escaped, "\"");
+          out << YAML::Value << hex_escaped;
+          break;
+        }
+        case ConstantKindCase::kString:
+          out << YAML::Key << "value";
+          out << YAML::Value << YAML::DoubleQuoted << constant.string_value();
+          break;
+        case ConstantKindCase::kDuration:
+          out << YAML::Key << "value" << YAML::Value;
+          // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations)
+          out << absl::FormatDuration(constant.duration_value());
+          break;
+        case ConstantKindCase::kTimestamp:
+          out << YAML::Key << "value" << YAML::Value;
+          out << absl::FormatTime(
+              "%4Y-%2m-%2d%ET%2H:%2M:%E*SZ",
+              // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations)
+              constant.timestamp_value(), absl::UTCTimeZone());
+          break;
+      }
+    }
+    out << YAML::EndMap;
+  }
+  out << YAML::EndSeq;
+}
+
+// Emits one overload as a map: `id`, then for member functions a `target`
+// map holding the first parameter, the remaining parameters as `args`, and
+// finally the `return` type.
+//
+// NOTE(review): `examples` are read by the parser but are not written here,
+// so they do not survive a parse/emit round trip. Confirm that this is
+// intended.
+void EmitFunctionOverloadConfig(
+    const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) {
+  out << YAML::BeginMap;
+  out << YAML::Key << "id";
+  out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id;
+  if (overload_config.is_member_function) {
+    out << YAML::Key << "target" << YAML::Value;
+    out << YAML::BeginMap;
+    if (overload_config.parameters.empty()) {
+      // This should never happen, but if it does, emit a dynamic type.
+      EmitTypeInfo({.name = "dyn"}, out);
+    } else {
+      EmitTypeInfo(overload_config.parameters[0], out);
+    }
+    out << YAML::EndMap;
+    if (overload_config.parameters.size() > 1) {
+      out << YAML::Key << "args";
+      out << YAML::Value << YAML::BeginSeq;
+      for (size_t i = 1; i < overload_config.parameters.size(); ++i) {
+        out << YAML::BeginMap;
+        EmitTypeInfo(overload_config.parameters[i], out);
+        out << YAML::EndMap;
+      }
+      out << YAML::EndSeq;
+    }
+  } else {
+    if (!overload_config.parameters.empty()) {
+      out << YAML::Key << "args";
+      out << YAML::Value << YAML::BeginSeq;
+      for (const Config::TypeInfo& parameter : overload_config.parameters) {
+        out << YAML::BeginMap;
+        EmitTypeInfo(parameter, out);
+        out << YAML::EndMap;
+      }
+      out << YAML::EndSeq;
+    }
+  }
+  out << YAML::Key << "return";
+  out << YAML::Value << YAML::BeginMap;
+  EmitTypeInfo(overload_config.return_type, out);
+  out << YAML::EndMap;
+
+  out << YAML::EndMap;
+}
+
+// Emits the `functions` sequence, one map per function, ordered by name.
+//
+// Each map carries `name`, an optional `description` and, when the function
+// has overloads, an `overloads` sequence written by
+// EmitFunctionOverloadConfig. Overloads are ordered lexicographically by
+// their parameter lists, so a list that is a prefix of another comes first.
+// Emits nothing when there are no functions.
+void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) {
+  const std::vector<Config::FunctionConfig>& function_configs =
+      env_config.GetFunctionConfigs();
+  if (function_configs.empty()) {
+    return;
+  }
+
+  // Sort function_configs by name to ensure deterministic output.
+  std::vector<Config::FunctionConfig> sorted_function_configs =
+      function_configs;
+  absl::c_sort(sorted_function_configs,
+               [](const Config::FunctionConfig& a,
+                  const Config::FunctionConfig& b) { return a.name < b.name; });
+
+  out << YAML::Key << "functions";
+  out << YAML::Value << YAML::BeginSeq;
+  for (const Config::FunctionConfig& function_config :
+       sorted_function_configs) {
+    out << YAML::BeginMap;
+    out << YAML::Key << "name";
+    out << YAML::Value << YAML::DoubleQuoted << function_config.name;
+    if (!function_config.description.empty()) {
+      out << YAML::Key << "description";
+      out << YAML::Value << YAML::DoubleQuoted << function_config.description;
+    }
+    if (!function_config.overload_configs.empty()) {
+      // Sort overloads for deterministic output.
+      std::vector<Config::FunctionOverloadConfig> sorted_overloads =
+          function_config.overload_configs;
+      // Lexicographic comparison of the parameter lists. This assumes
+      // CompareTypeInfo(x, y) is a strict "x orders before y" predicate, as
+      // its symmetric use below suggests.
+      absl::c_sort(sorted_overloads,
+                   [](const Config::FunctionOverloadConfig& a,
+                      const Config::FunctionOverloadConfig& b) {
+                     for (size_t i = 0; i < a.parameters.size(); ++i) {
+                       // `b` is a strict prefix of `a`, so `b` sorts first.
+                       if (i >= b.parameters.size()) {
+                         return false;
+                       }
+                       if (CompareTypeInfo(a.parameters[i], b.parameters[i])) {
+                         return true;
+                       }
+                       if (CompareTypeInfo(b.parameters[i], a.parameters[i])) {
+                         return false;
+                       }
+                     }
+                     // All of `a`'s parameters matched. `a` sorts first only
+                     // if it is the shorter list: foo(a), foo(a, b).
+                     // Returning false here would make a prefix "equivalent"
+                     // to every extension of it, which is not a strict weak
+                     // ordering.
+                     return a.parameters.size() < b.parameters.size();
+                   });
+
+      out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq;
+      for (const Config::FunctionOverloadConfig& overload_config :
+           sorted_overloads) {
+        EmitFunctionOverloadConfig(overload_config, out);
+      }
+      out << YAML::EndSeq;
+    }
+    out << YAML::EndMap;
+  }
+  out << YAML::EndSeq;
+}
+}  // namespace
+
+// Loads `yaml` and fills a Config from its name, container, extensions,
+// stdlib, variables and functions sections, in that order. The first failing
+// section aborts parsing and its status is returned.
+absl::StatusOr<Config> EnvConfigFromYaml(const std::string& yaml) {
+  Config config;
+  CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml));
+  CEL_RETURN_IF_ERROR(ParseName(config, yaml, root));
+  CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root));
+  CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root));
+  CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root));
+  CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root));
+  CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root));
+  return config;
+}
+
+// Writes `env_config` to `os` as a single YAML map with two-space indentation.
+// `name` is written only when non-empty; each remaining section is written by
+// its Emit* helper.
+void EnvConfigToYaml(const Config& env_config, std::ostream& os) {
+  YAML::Emitter out(os);
+  out.SetIndent(2);
+  out << YAML::BeginMap;
+  if (!env_config.GetName().empty()) {
+    out << YAML::Key << "name";
+    out << YAML::Value << YAML::DoubleQuoted << env_config.GetName();
+  }
+  EmitContainerConfig(env_config, out);
+  EmitExtensionConfigs(env_config, out);
+  EmitStandardLibraryConfig(env_config, out);
+  EmitVariableConfigs(env_config, out);
+  EmitFunctionConfigs(env_config, out);
+  out << YAML::EndMap;
+}
+
+}  // namespace cel
diff --git a/env/env_yaml.h b/env/env_yaml.h
new file mode 100644
index
000000000..c96b45933
--- /dev/null
+++ b/env/env_yaml.h
@@ -0,0 +1,39 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_
+#define THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_
+
+#include <ostream>
+#include <string>
+
+#include "absl/status/statusor.h"
+#include "env/config.h"
+
+namespace cel {
+
+// EnvConfigFromYaml creates an environment configuration from a YAML string.
+//
+// To ensure safety, only pass trusted YAML input. yaml-cpp has some fuzz
+// coverage, but its security model is unclear. Additionally, callers should be
+// aware that improper CEL configuration can lead to unsafe or unpredictably
+// expensive expressions.
+absl::StatusOr<Config> EnvConfigFromYaml(const std::string& yaml);
+
+// EnvConfigToYaml serializes an environment configuration as YAML into `os`.
+void EnvConfigToYaml(const Config& env_config, std::ostream& os);
+
+}  // namespace cel
+
+#endif  // THIRD_PARTY_CEL_CPP_ENV_ENV_YAML_H_
diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc
new file mode 100644
index 000000000..828a39b48
--- /dev/null
+++ b/env/env_yaml_test.cc
@@ -0,0 +1,1467 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "env/env_yaml.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <sstream>  // NOTE(review): header name lost in extraction - confirm.
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/status_matchers.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
+#include "common/constant.h"
+#include "env/config.h"
+#include "internal/status_macros.h"
+#include "internal/testing.h"
+
+namespace cel {
+namespace {
+
+using ::absl_testing::StatusIs;
+using ::testing::AllOf;
+using ::testing::ElementsAreArray;
+using ::testing::Field;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+// The `container` scalar becomes ContainerConfig::name.
+TEST(EnvYamlTest, ParseContainerConfig) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    container: "test.container"
+  )yaml"));
+
+  EXPECT_THAT(config.GetContainerConfig(),
+              Field(&Config::ContainerConfig::name, "test.container"));
+}
+
+// `version` accepts a number or `latest`, and defaults to kLatest when absent.
+TEST(EnvYamlTest, ParseExtensionConfigs) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    extensions:
+    - name: "math"
+      version: latest
+    - name: "optional"
+      version: 2
+    - name: "strings"
+  )yaml"));
+
+  EXPECT_THAT(config.GetExtensionConfigs(),
+              UnorderedElementsAre(
+                  AllOf(Field(&Config::ExtensionConfig::name, "math"),
+                        Field(&Config::ExtensionConfig::version,
+                              Config::ExtensionConfig::kLatest)),
+                  AllOf(Field(&Config::ExtensionConfig::name, "optional"),
+                        Field(&Config::ExtensionConfig::version, 2)),
+                  AllOf(Field(&Config::ExtensionConfig::name, "strings"),
+                        Field(&Config::ExtensionConfig::version,
+                              Config::ExtensionConfig::kLatest))));
+}
+
+// An empty document yields no extensions.
+TEST(EnvYamlTest, DefaultExtensionConfigs) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+  )yaml"));
+
+  EXPECT_THAT(config.GetExtensionConfigs(), IsEmpty());
+}
+
+// Exclusion lists: a function without `overloads` is recorded with an empty
+// overload id.
+TEST(EnvYamlTest, ParseStdlibConfig_ExclusionStyle) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    stdlib:
+      disable: true
+      disable_macros: true
+      exclude_macros:
+      - map
+      - filter
+      exclude_functions:
+      - name: "_+_"
+        overloads:
+        - id: add_bytes
+        - id: add_list
+      - name: "matches"
+      - name: "timestamp"
+        overloads:
+        - id: "string_to_timestamp"
+  )yaml"));
+
+  const auto& stdlib_config = config.GetStandardLibraryConfig();
+  EXPECT_TRUE(stdlib_config.disable);
+  EXPECT_TRUE(stdlib_config.disable_macros);
+  EXPECT_THAT(stdlib_config.excluded_macros,
+              UnorderedElementsAre("map", "filter"));
+  EXPECT_THAT(stdlib_config.included_macros, IsEmpty());
+  EXPECT_THAT(
+      stdlib_config.excluded_functions,
+      UnorderedElementsAre(std::make_pair("_+_", "add_bytes"),
+                           std::make_pair("_+_", "add_list"),
+                           std::make_pair("matches", ""),
+                           std::make_pair("timestamp", "string_to_timestamp")))
+      << " Actual stdlib config: " << stdlib_config;
+}
+
+// Inclusion lists mirror the exclusion lists.
+TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    stdlib:
+      include_macros:
+      - map
+      - filter
+      include_functions:
+      - name: "_+_"
+        overloads:
+        - id: add_bytes
+        - id: add_list
+      - name: "matches"
+      - name: "timestamp"
+        overloads:
+        - id: "string_to_timestamp"
+  )yaml"));
+
+  const auto& stdlib_config = config.GetStandardLibraryConfig();
+  EXPECT_THAT(stdlib_config.included_macros,
+              UnorderedElementsAre("map", "filter"));
+  EXPECT_THAT(
+      stdlib_config.included_functions,
+      UnorderedElementsAre(std::make_pair("_+_", "add_bytes"),
+                           std::make_pair("_+_", "add_list"),
+                           std::make_pair("matches", ""),
+                           std::make_pair("timestamp", "string_to_timestamp")))
+      << " Actual stdlib config: " << stdlib_config;
+}
+
+// A variable with a message type and a folded multi-line description.
+TEST(EnvYamlTest, ParseVariableConfigs) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    variables:
+    - name: "msg"
+      type_name: "google.expr.proto3.test.TestAllTypes"
+      description: >-
+        msg represents all possible type permutation which
+        CEL understands from a proto perspective
+  )yaml"));
+
+  const Config::VariableConfig& variable_config =
+      config.GetVariableConfigs()[0];
+  EXPECT_EQ(variable_config.name, "msg");
+  const auto& type_info = variable_config.type_info;
+  EXPECT_EQ(type_info.name, "google.expr.proto3.test.TestAllTypes");
+  EXPECT_FALSE(type_info.is_type_param);
+  EXPECT_THAT(type_info.params, IsEmpty());
+  EXPECT_EQ(variable_config.description,
+            "msg represents all possible type permutation which CEL "
+            "understands from a proto perspective");
+}
+
+// Nested `params` entries are parsed into TypeInfo::params.
+TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) {
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml(
+    variables:
+    - name: "dict"
+      type_name: "map"
+      params:
+      - type_name: "string"
+      - type_name: "A"
+        is_type_param: true
+  )yaml"));
+
+  const Config::VariableConfig& variable_config =
+      config.GetVariableConfigs()[0];
+  EXPECT_EQ(variable_config.name, "dict");
+  const auto& type_info = variable_config.type_info;
+  EXPECT_EQ(type_info.name, "map");
+  EXPECT_FALSE(type_info.is_type_param);
+  EXPECT_THAT(type_info.params, SizeIs(2));
+  EXPECT_EQ(type_info.params[0].name, "string");
+  EXPECT_FALSE(type_info.params[0].is_type_param);
+  EXPECT_THAT(type_info.params[0].params, IsEmpty());
+  EXPECT_EQ(type_info.params[1].name, "A");
+  EXPECT_TRUE(type_info.params[1].is_type_param);
+  EXPECT_THAT(type_info.params[1].params, IsEmpty());
+}
+
+// One case for the parameterized constant-parsing test below.
+struct ParseConstantTestCase {
+  std::string type_name;
+  std::string value;
+  std::string expected_error;  // Empty if no error.
+  Constant expected_constant;
+};
+
+class EnvYamlParseConstantTest
+    : public testing::TestWithParam<ParseConstantTestCase> {};
+
+// Substitutes the case's type name and raw value text into a one-variable
+// document and checks either the error substring or the parsed constant.
+TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) {
+  const ParseConstantTestCase& param = GetParam();
+  const std::string yaml = absl::StrFormat(
+      R"yaml(
+    variables:
+    - name: "const"
+      type_name: "%s"
+      value: %s
+  )yaml",
+      param.type_name, param.value);
+  absl::StatusOr<Config> status_or_config = EnvConfigFromYaml(yaml);
+  if (!param.expected_error.empty()) {
+    EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument,
+                                           HasSubstr(param.expected_error)));
+    return;
+  }
+  ASSERT_OK_AND_ASSIGN(Config config, status_or_config);
+
+  const Config::VariableConfig& variable_config =
+      config.GetVariableConfigs()[0];
+  EXPECT_EQ(variable_config.name, "const");
+  EXPECT_EQ(variable_config.type_info.name, param.type_name);
+  EXPECT_EQ(variable_config.value, param.expected_constant);
+}
+
+// `value` is spliced into the YAML verbatim, so quoting and tags such as
+// `!!binary` are part of the case.
+std::vector<ParseConstantTestCase> GetParseConstantTestCases() {
+  return {
+      ParseConstantTestCase{
+          .type_name = "null",
+          .value = "\"\"",
+          .expected_constant = Constant(nullptr),
+      },
+      ParseConstantTestCase{
+          .type_name = "null",
+          .value = "anything",
+          .expected_error = "Failed to parse null constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "bool",
+          .value = "TRUE",
+          .expected_constant = Constant(true),
+      },
+      ParseConstantTestCase{
+          .type_name = "bool",
+          .value = "false",
+          .expected_constant = Constant(false),
+      },
+      ParseConstantTestCase{
+          .type_name = "bool",
+          .value = "yes",
+          .expected_error = "Failed to parse bool constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "int",
+          .value = "42",
+          .expected_constant = Constant(int64_t{42}),
+      },
+      ParseConstantTestCase{
+          .type_name = "int",
+          .value = "41.999",
+          .expected_error = "Failed to parse int constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "uint",
+          .value = "42",
+          .expected_constant = Constant(uint64_t{42}),
+      },
+      ParseConstantTestCase{
+          .type_name = "uint",
+          .value = "42u",
+          .expected_constant = Constant(uint64_t{42}),
+      },
+      ParseConstantTestCase{
+          .type_name = "uint",
+          .value = "-1",
+          .expected_error = "Failed to parse uint constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "double",
+          .value = "42.42",
+          .expected_constant = Constant(42.42),
+      },
+      ParseConstantTestCase{
+          .type_name = "double",
+          .value = "abc",
+          .expected_error = "Failed to parse double constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "bytes",
+          .value = "abc",
+          .expected_constant = Constant(BytesConstant("abc")),
+      },
+      ParseConstantTestCase{
+          .type_name = "bytes",
+          .value = "b\"\\xFF\\x00\\x01\"",
+          .expected_constant =
+              Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))),
+      },
+      ParseConstantTestCase{
+          .type_name = "bytes",
+          .value = "!!binary /wAB",
+          .expected_constant =
+              Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))),
+      },
+      ParseConstantTestCase{
+          .type_name = "bytes",
+          .value = "!!binary YWJj=",
+          .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary",
+      },
+      ParseConstantTestCase{
+          .type_name = "bytes",
+          .value = "abc",
+          .expected_constant = Constant(BytesConstant("abc")),
+      },
+      ParseConstantTestCase{
+          .type_name = "string",
+          .value = "abc",
+          .expected_constant = Constant(StringConstant("abc")),
+      },
+      ParseConstantTestCase{
+          .type_name = "string",
+          .value = "\"\\\"abc\\\"\"",
+          .expected_constant = Constant(StringConstant("\"abc\"")),
+      },
+      ParseConstantTestCase{
+          .type_name = "duration",
+          .value = "1s",
+          .expected_constant = Constant(absl::Seconds(1)),
+      },
+      ParseConstantTestCase{
+          .type_name = "duration",
+          .value = "abc",
+          .expected_error = "Failed to parse duration constant",
+      },
+      ParseConstantTestCase{
+          .type_name = "timestamp",
+          .value = "2023-01-01T00:00:00Z",
+          .expected_constant = Constant(absl::FromUnixSeconds(1672531200)),
+      },
+      ParseConstantTestCase{
+          .type_name = "timestamp",
+          .value = "abc",
+          .expected_error = "Failed to parse timestamp constant",
+      },
+  };
+}
+
+INSTANTIATE_TEST_SUITE_P(EnvYamlParseConstantTest, EnvYamlParseConstantTest,
+                         ::testing::ValuesIn(GetParseConstantTestCases()));
+
+// One case for the parameterized function-parsing test: a YAML document that
+// declares exactly one function, and the FunctionConfig it should produce.
+struct ParseFunctionTestCase {
+  std::string yaml;
+  Config::FunctionConfig expected_function_config;
+};
+
+class EnvYamlParseFunctionTest
+    : public testing::TestWithParam<ParseFunctionTestCase> {};
+
+// Recursively compares name, is_type_param and params of two TypeInfo values.
+void ExpectTypeInfoEqual(const Config::TypeInfo& actual,
+                         const Config::TypeInfo& expected) {
+  EXPECT_EQ(actual.name, expected.name);
+  EXPECT_EQ(actual.is_type_param, expected.is_type_param);
+  ASSERT_THAT(actual.params, SizeIs(expected.params.size()));
+  for (size_t i = 0; i < expected.params.size(); ++i) {
+    ExpectTypeInfoEqual(actual.params[i], expected.params[i]);
+  }
+}
+
+// Parses the case's YAML and compares the single resulting function, overload
+// by overload, against the expected config.
+TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) {
+  const ParseFunctionTestCase& param = GetParam();
+  ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(param.yaml));
+
+  ASSERT_THAT(config.GetFunctionConfigs(), SizeIs(1));
+  const Config::FunctionConfig& function_config =
+      config.GetFunctionConfigs()[0];
+  const Config::FunctionConfig& expected = param.expected_function_config;
+
+  EXPECT_EQ(function_config.name, expected.name);
+  EXPECT_EQ(function_config.description, expected.description);
+
+  ASSERT_THAT(function_config.overload_configs,
+              SizeIs(expected.overload_configs.size()));
+
+  for (size_t i = 0; i < expected.overload_configs.size(); ++i) {
+    const auto& actual_overload = function_config.overload_configs[i];
+    const auto& expected_overload = expected.overload_configs[i];
+
+    EXPECT_EQ(actual_overload.overload_id, expected_overload.overload_id);
+    EXPECT_THAT(actual_overload.examples,
+                ElementsAreArray(expected_overload.examples));
+    EXPECT_EQ(actual_overload.is_member_function,
+              expected_overload.is_member_function);
+
+    ASSERT_THAT(actual_overload.parameters,
+                SizeIs(expected_overload.parameters.size()));
+    for (size_t j = 0; j < expected_overload.parameters.size(); ++j) {
+      ExpectTypeInfoEqual(actual_overload.parameters[j],
+                          expected_overload.parameters[j]);
+    }
+
+    ExpectTypeInfoEqual(actual_overload.return_type,
+                        expected_overload.return_type);
+  }
+}
+
+std::vector<ParseFunctionTestCase> GetParseFunctionTestCases() {
+  return {
+      ParseFunctionTestCase{
+          .yaml = R"yaml(
+            functions:
+            - name: "isEmpty"
+              description: |-
+                determines whether a list is empty,
+                or a string has no characters
+              overloads:
+              - id: "wrapper_string_isEmpty"
+                examples:
+                - "''.isEmpty() // true"
+                target:
+                  type_name: "google.protobuf.StringValue"
+                return:
+                  type_name: "bool"
+              - id: "list_isEmpty"
+                examples:
+                - "[].isEmpty() // true"
+                - "[1].isEmpty() // false"
+                target:
+                  type_name: "list"
+                  params:
+                  - type_name: "T"
+                    is_type_param: true
+                return:
+                  type_name: "bool"
+          )yaml",
+          .expected_function_config =
+              {
+                  .name = "isEmpty",
+                  .description = "determines whether a list is empty,\nor a "
+                                 "string has no characters",
+                  .overload_configs =
+                      {
+                          Config::FunctionOverloadConfig{
+                              .overload_id = "wrapper_string_isEmpty",
+                              .examples = {"''.isEmpty() // true"},
+                              .is_member_function = true,
+                              .parameters =
+                                  {{.name = "google.protobuf.StringValue"}},
+                              .return_type = {.name = "bool"},
+                          },
+                          Config::FunctionOverloadConfig{
+                              .overload_id = "list_isEmpty",
+                              .examples = {"[].isEmpty() // true",
+                                           "[1].isEmpty() // false"},
+                              .is_member_function = true,
+                              .parameters = {{.name = "list",
+                                              .params = {{.name = "T",
+                                                          .is_type_param =
+                                                              true}}}},
+                              .return_type = {.name = "bool"},
+                          },
+                      },
+              },
+      },
+      ParseFunctionTestCase{
+          .yaml = R"yaml(
+            functions:
+            - name: "contains"
+              overloads:
+              - id: "global_contains"
+                examples:
+                - "contains([1, 2, 3], 2) // true"
+                args:
+                - type_name: "list"
+                  params:
+                  - type_name: "T"
+                    is_type_param: true
+                - type_name: "T"
+                  is_type_param: true
+                return:
+                  type_name: "bool"
+          )yaml",
+          .expected_function_config =
+              {
+                  .name = "contains",
+                  .overload_configs =
+                      {
+                          Config::FunctionOverloadConfig{
+                              .overload_id = "global_contains",
+                              .examples = {"contains([1, 2, 3], 2) // true"},
+
.is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlParseFunctionTest, EnvYamlParseFunctionTest, + ::testing::ValuesIn(GetParseFunctionTestCases())); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +class EnvYamlParseTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { + const ParseTestCase& param = GetParam(); + absl::StatusOr config = EnvConfigFromYaml(param.yaml); + EXPECT_THAT(config, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlParseTest, EnvYamlParseTest, + ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( + name: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'name' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + - error: "error" + )yaml", + .expected_error = "3:19: Node 'container' is not a string\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + -name: "optional" + - name: "other" + )yaml", + .expected_error = "5:21: end of map not found\n" + "| - name: \"other\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: "bar" + )yaml", + .expected_error = "2:27: Node 'extensions' is not a sequence\n" + "| extensions: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: + - something: "bar" + )yaml", + .expected_error = "4:19: Extension name is not a string\n" + "| - something: \"bar\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: last + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: last\n" + "| ^", + }, + 
ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: -15 + )yaml", + .expected_error = "4:28: Extension 'math' version is not a valid " + "number or 'latest'\n" + "| version: -15\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + extensions: + - name: "math" + version: 1 + - name: "math" + version: 2 + )yaml", + .expected_error = "5:19: Extension 'math' version 1 is already " + "included. Cannot also include version 2\n" + "| - name: \"math\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: "error" + )yaml", + .expected_error = "2:23: Standard library config ('stdlib') " + "is not a map\n" + "| stdlib: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable: "error" + )yaml", + .expected_error = "3:26: Node 'disable' is not a boolean\n" + "| disable: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + disable_macros: "error" + )yaml", + .expected_error = "3:33: Node 'disable_macros' is not a boolean\n" + "| disable_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: "error" + )yaml", + .expected_error = "3:33: Node 'exclude_macros' is not a sequence\n" + "| exclude_macros: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + exclude_macros: + - foo: "error" + )yaml", + .expected_error = "4:19: Entry in 'exclude_macros' " + "is not a string\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: "error" + )yaml", + .expected_error = "3:36: Node 'include_functions' " + "is not a sequence\n" + "| include_functions: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - "error" + )yaml", + .expected_error = "4:19: Entry in 'include_functions' " + "is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - foo: "error" + )yaml", + .expected_error = "4:19: Function 
name in not specified in " + "'include_functions'\n" + "| - foo: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "5:30: Overloads in 'include_functions' entry " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - foo_string + )yaml", + .expected_error = "6:21: Overload in 'include_functions' entry " + "is not a map\n" + "| - foo_string\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + stdlib: + include_functions: + - name: "foo" + overloads: + - id: + - foo_int64 + )yaml", + .expected_error = "7:21: Overload id in 'include_functions' entry " + "is not a string\n" + "| - foo_int64\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: + - type_name: "opaque" + )yaml", + .expected_error = "4:19: Variable name is not a string\n" + "| - type_name: \"opaque\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: + - params: + )yaml", + .expected_error = "5:21: Node 'type_name' is not a string\n" + "| - params:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + params: + - type_name: "int" + - type_name: "A" + is_type_param: maybe + )yaml", + .expected_error = "8:38: Node 'is_type_param' is not a boolean\n" + "| is_type_param: maybe\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: -1 + )yaml", + .expected_error = "5:26: Failed to parse uint constant\n" + "| value: -1\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: many + )yaml", + .expected_error = "2:26: Node 'functions' is not a sequence\n" + "| functions: many\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: + - overloads: + )yaml", + .expected_error = "4:19: Function name is 
not a string\n" + "| - overloads:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: "error" + )yaml", + .expected_error = "4:30: Function 'overloads' item " + "is not a sequence\n" + "| overloads: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: + - "error" + )yaml", + .expected_error = "6:25: Function overload id is not a string\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + - "error" + )yaml", + .expected_error = "7:25: Function overload target is not a map\n" + "| - \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + target: + type_name: "Foo" + params: + - type_name: + - is_type_param: true + )yaml", + .expected_error = "10:31: Node 'type_name' is not a string\n" + "| " + "- is_type_param: true\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + args: "a bunch" + )yaml", + .expected_error = "6:29: Function overload args is not a sequence\n" + "| args: \"a bunch\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + return: "to sender" + )yaml", + .expected_error = "6:31: Function overload return type" + " is not a map\n" + "| return: \"to sender\"\n" + "| ^", + })); + +std::string Unindent(std::string_view yaml) { + std::vector lines = absl::StrSplit(yaml, '\n'); + int indent = -1; + std::vector unindented_lines; + for (auto& line : lines) { + std::size_t pos = line.find_first_not_of(" \t"); + if (pos == std::string::npos) { + // Skip blank lines. 
+ continue; + } + if (indent == -1) { + indent = pos; + } + if (pos >= indent) { + unindented_lines.push_back(line.substr(indent)); + } else { + unindented_lines.push_back(line); + } + } + return absl::StrJoin(unindented_lines, "\n"); +} + +struct ExportTestCase { + absl::StatusOr config; + std::string expected_yaml; +}; + +class EnvYamlExportTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlExportTest, EnvYamlExport) { + const ExportTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config config, param.config); + std::stringstream ss; + EnvConfigToYaml(config, ss); + std::string yaml_output = Unindent(ss.str()); + std::string expected_yaml = Unindent(param.expected_yaml); + EXPECT_EQ(yaml_output, expected_yaml); +} + +std::vector GetExportTestCases() { + return { + ExportTestCase{ + .config = + []() { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig({.name = "test.container"}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: "test.container" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("math")); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("optional", 2)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig("bindings")); + return config; + }(), + .expected_yaml = R"yaml( + extensions: + - name: "bindings" + - name: "math" + - name: "optional" + version: 2 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable: true + )yaml", 
+ }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .disable_macros = true, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + disable_macros: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_macros = {"map", "filter"}, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + include_macros: + - "filter" + - "map" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .excluded_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.SetStandardLibraryConfig(Config::StandardLibraryConfig{ + .included_functions = + { + std::make_pair("timestamp", "string_to_timestamp"), + std::make_pair("_+_", "add_list"), + std::make_pair("matches", ""), + std::make_pair("_+_", "add_bytes"), + }, + })); + return config; + }(), + .expected_yaml = R"yaml( + stdlib: + 
include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "null"}, + .value = Constant(nullptr)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "null" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "bool"}, + .value = Constant(true)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bool" + value: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "int"}, + .value = Constant(int64_t{42})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "int" + value: 42 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "uint"}, + .value = Constant(uint64_t{777})})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "uint" + value: 777 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR( + config.AddVariableConfig({.name = "foo", + .type_info = {.name = "double"}, + .value = Constant(0.75)})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "double" + value: 0.75 + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = 
{.name = "bytes"}, + .value = Constant( + BytesConstant(absl::string_view("\xff\x00\x01", 3)))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "bytes" + value: b"\xff\x00\x01" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + Constant c; + c.set_string_value("'single' \"double\""); + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "string"}, + .value = Constant(StringConstant("'single' \"double\""))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "string" + value: "'single' \"double\"" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "duration"}, + .value = Constant(absl::Hours(1) + absl::Minutes(2) + + absl::Seconds(3))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "duration" + value: 1h2m3s + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = "timestamp"}, + .value = Constant(absl::FromUnixSeconds(1767323045))})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = {.name = + "google.expr.proto3.test.TestAllTypes"}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "google.expr.proto3.test.TestAllTypes" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddVariableConfig( + {.name = "foo", + .type_info = { + .name = "A", + .params = 
{{.name = "int"}, + {.name = "B", .is_type_param = true}}}})); + return config; + }(), + .expected_yaml = R"yaml( + variables: + - name: "foo" + type_name: "A" + params: + - type_name: "int" + - type_name: "B" + is_type_param: true + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig({.name = "foo"})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_id", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "foo_overload_a", + .parameters = {{.name = "timestamp"}}, + .return_type = {.name = "list", + .params = {{.name = "int"}}}}, + {.overload_id = "foo_overload_b", + .parameters = {{.name = "double"}, + {.name = "A", .params = {{.name = "B"}}}}, + .return_type = {.name = "string"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_b" + args: + - type_name: "double" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "string" + - id: "foo_overload_a" + args: + - type_name: "timestamp" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }, + }; +}; + +INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, 
EnvYamlExportTest, + ::testing::ValuesIn(GetExportTestCases())); + +class EnvYamlRoundTripTest : public testing::TestWithParam {}; + +TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetRoundTripTestCases() { + return { + R"yaml( + stdlib: + disable: true + disable_macros: true + )yaml", + R"yaml( + name: "test.env" + container: "common.proto.prefix" + extensions: + - name: "math" + version: 0 + - name: "optional" + version: 2 + stdlib: + include_macros: + - "filter" + - "map" + include_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + extensions: + - name: "bindings" + - name: "math" + stdlib: + exclude_macros: + - "filter" + - "map" + exclude_functions: + - name: "_+_" + overloads: + - id: "add_bytes" + - id: "add_list" + - name: "matches" + - name: "timestamp" + overloads: + - id: "string_to_timestamp" + )yaml", + R"yaml( + variables: + - name: "a" + type_name: "null" + - name: "b" + type_name: "bool" + value: true + - name: "c" + type_name: "int" + value: 42 + - name: "d" + type_name: "uint" + value: 777 + - name: "e" + type_name: "double" + value: 0.75 + - name: "f" + type_name: "bytes" + value: b"\xff\x00\x01" + - name: "g" + type_name: "string" + value: "plain 'single' \"double\"" + - name: "h" + type_name: "duration" + value: 1h2m3s + - name: "i" + type_name: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "int" + )yaml", + R"yaml( + functions: + - name: 
"foo" + overloads: + - id: "foo_overload_id" + args: + - type_name: "timestamp" + - type_name: "A" + params: + - type_name: "B" + return: + type_name: "list" + params: + - type_name: "int" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, + ::testing::ValuesIn(GetRoundTripTestCases())); + +} // namespace +} // namespace cel diff --git a/env/internal/BUILD b/env/internal/BUILD new file mode 100644 index 000000000..ec4a0b15c --- /dev/null +++ b/env/internal/BUILD @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ext_registry", + srcs = ["ext_registry.cc"], + hdrs = ["ext_registry.h"], + deps = [ + "//compiler", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "runtime_ext_registry", + srcs = ["runtime_ext_registry.cc"], + hdrs = ["runtime_ext_registry.h"], + deps = [ + "//runtime:runtime_builder", + "//runtime:runtime_options", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ext_registry_test", + srcs = ["ext_registry_test.cc"], + deps = [ + ":ext_registry", + "//checker:type_checker_builder", + "//compiler", + "//internal:testing", + "//parser:parser_interface", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) + +cc_test( + name = "runtime_ext_registry_test", + srcs = ["runtime_ext_registry_test.cc"], + deps = [ + ":runtime_ext_registry", + "//common:ast", + "//common:source", + "//common:value", + "//common:value_testing", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//parser:options", + "//parser:parser_interface", + "//runtime", + "//runtime:activation", + "//runtime:function", + "//runtime:function_adapter", + "//runtime:function_registry", + "//runtime:runtime_builder", + "//runtime:runtime_builder_factory", + "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/env/internal/ext_registry.cc b/env/internal/ext_registry.cc new file mode 100644 index 000000000..b32239ac3 --- /dev/null +++ b/env/internal/ext_registry.cc 
@@ -0,0 +1,63 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/internal/ext_registry.h" + +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +void ExtensionRegistry::RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) { + library_registry_.push_back( + LibraryRegistration(name, alias, version, std::move(library_factory))); +} + +absl::StatusOr ExtensionRegistry::GetCompilerLibrary( + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: ", name)); + } + version = max_version; + } + for (const auto& registration : library_registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.GetLibrary(); + } + } + + return absl::NotFoundError( + absl::StrCat("CompilerLibrary not registered: 
", name, "#", version)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/ext_registry.h b/env/internal/ext_registry.h new file mode 100644 index 000000000..ab5b67a24 --- /dev/null +++ b/env/internal/ext_registry.h @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "compiler/compiler.h" + +namespace cel { +namespace env_internal { + +// A registry for CEL compiler extension libraries. +// +// Used to register and retrieve CompilerLibraries by name (or alias) and +// version. +class ExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void RegisterCompilerLibrary( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory); + + absl::StatusOr GetCompilerLibrary(absl::string_view name, + int version) const; + + private: + class LibraryRegistration final { + public: + LibraryRegistration( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable library_factory) + : name_(name), + alias_(!alias.empty() ? 
alias : name), + version_(version), + factory_(std::move(library_factory)) {} + + CompilerLibrary GetLibrary() const { return factory_(); } + + private: + std::string name_; + std::string alias_; + int version_; + absl::AnyInvocable factory_; + + friend class ExtensionRegistry; + }; + + std::vector library_registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_INTERNAL_EXT_REGISTRY_H_ diff --git a/env/internal/ext_registry_test.cc b/env/internal/ext_registry_test.cc new file mode 100644 index 000000000..9e345c781 --- /dev/null +++ b/env/internal/ext_registry_test.cc @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/internal/ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" +#include "internal/testing.h" +#include "parser/parser_interface.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::testing::Field; +using ::testing::HasSubstr; + +TEST(ExtensionRegistryTest, GetCompilerLibrary) { + ExtensionRegistry registry; + registry.RegisterCompilerLibrary("foo1", "f", 1, []() { + return CompilerLibrary("foo1_1", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo1", "f", 2, []() { + return CompilerLibrary("foo1_2", nullptr, nullptr); + }); + registry.RegisterCompilerLibrary("foo2", "", 1, []() { + return CompilerLibrary("foo2_1", nullptr, nullptr); + }); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 2), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo1", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("f", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo1_2"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", 1), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo2", ExtensionRegistry::kLatest), + IsOkAndHolds(Field(&CompilerLibrary::id, "foo2_1"))); + + EXPECT_THAT(registry.GetCompilerLibrary("foo1", 3), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo1#3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", 1), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary 
not registered: foo3"))); + EXPECT_THAT(registry.GetCompilerLibrary("foo3", ExtensionRegistry::kLatest), + StatusIs(absl::StatusCode::kNotFound, + HasSubstr("CompilerLibrary not registered: foo3"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/internal/runtime_ext_registry.cc b/env/internal/runtime_ext_registry.cc new file mode 100644 index 000000000..dc78a38e3 --- /dev/null +++ b/env/internal/runtime_ext_registry.cc @@ -0,0 +1,64 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/internal/runtime_ext_registry.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +void RuntimeExtensionRegistry::AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) { + registry_.push_back(Registration(name, alias, version, + std::move(function_registration_callback))); +} + +absl::Status RuntimeExtensionRegistry::RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options, + absl::string_view name, int version) const { + if (version == kLatest) { + int max_version = -1; + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ > max_version) { + max_version = registration.version_; + } + } + if (max_version == -1) { + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); + } + version = max_version; + } + for (const Registration& registration : registry_) { + if ((registration.name_ == name || registration.alias_ == name) && + registration.version_ == version) { + return registration.RegisterExtensionFunctions(runtime_builder, + runtime_options); + } + } + + return absl::NotFoundError(absl::StrCat( + "Runtime functions are not registered for extension: ", name)); +} +} // namespace env_internal +} // namespace cel diff --git a/env/internal/runtime_ext_registry.h b/env/internal/runtime_ext_registry.h new file mode 100644 index 000000000..67838519f --- /dev/null +++ b/env/internal/runtime_ext_registry.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the 
License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { +namespace env_internal { + +using FunctionRegistrationCallback = absl::AnyInvocable; + +// A registry for CEL runtime extension functions. +// +// Used to register runtime functions for extensions by name (or alias) and +// version. +class RuntimeExtensionRegistry { + public: + static constexpr int kLatest = std::numeric_limits::max(); + + void AddFunctionRegistration( + absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback); + + absl::Status RegisterExtensionFunctions(RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options, + absl::string_view name, + int version) const; + + private: + class Registration final { + public: + Registration(absl::string_view name, absl::string_view alias, int version, + FunctionRegistrationCallback function_registration_callback) + : name_(name), + alias_(!alias.empty() ? 
alias : name), + version_(version), + function_registration_callback_( + std::move(function_registration_callback)) {} + + absl::Status RegisterExtensionFunctions( + RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) const { + return function_registration_callback_(runtime_builder, runtime_options); + } + + private: + std::string name_; + std::string alias_; + int version_; + FunctionRegistrationCallback function_registration_callback_; + + friend class RuntimeExtensionRegistry; + }; + + std::vector registry_; +}; + +} // namespace env_internal +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_RUNTIME_EXT_REGISTRY_H_ diff --git a/env/internal/runtime_ext_registry_test.cc b/env/internal/runtime_ext_registry_test.cc new file mode 100644 index 000000000..c6125d20f --- /dev/null +++ b/env/internal/runtime_ext_registry_test.cc @@ -0,0 +1,126 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/internal/runtime_ext_registry.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/source.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/options.h" +#include "parser/parser.h" +#include "parser/parser_interface.h" +#include "runtime/activation.h" +#include "runtime/function.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_builder_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" + +namespace cel::env_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::cel::test::StringValueIs; + +Value Hello1(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, old " + input.ToString() + "!", + context.arena()); +} + +Value Hello2(const StringValue& input, const Function::InvokeContext& context) { + return StringValue::From("Hello, new " + input.ToString() + "!", + context.arena()); +} + +RuntimeExtensionRegistry GetRuntimeExtensionRegistry() { + RuntimeExtensionRegistry registry; + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 1, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("hello", &Hello1, + runtime_builder.function_registry()); + }); + registry.AddFunctionRegistration( + "hello_extension", "hello_extension_alias", 2, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::UnaryFunctionAdapter:: + RegisterMemberOverload("hello", &Hello2, + runtime_builder.function_registry()); + }); + return registry; +} + +class 
RuntimeExtensionRegistryTest : public testing::Test { + protected: + absl::StatusOr Run(std::string_view extension_name, int version, + std::string_view expr) { + const RuntimeExtensionRegistry registry = GetRuntimeExtensionRegistry(); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr parser, + NewParserBuilder(ParserOptions())->Build()); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr source, NewSource(expr, "")); + CEL_ASSIGN_OR_RETURN(std::unique_ptr ast, parser->Parse(*source)); + + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + cel::RuntimeOptions runtime_options; + CEL_ASSIGN_OR_RETURN( + cel::RuntimeBuilder runtime_builder, + cel::CreateRuntimeBuilder(descriptor_pool, runtime_options)); + + CEL_RETURN_IF_ERROR(registry.RegisterExtensionFunctions( + runtime_builder, runtime_options, extension_name, version)); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + std::move(runtime_builder).Build()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + + Activation activation; + return program->Evaluate(&arena_, activation); + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_F(RuntimeExtensionRegistryTest, SpecificExtensionVersion) { + EXPECT_THAT(Run("hello_extension", 1, "hello('world')"), + IsOkAndHolds(StringValueIs("Hello, old world!"))); +} + +TEST_F(RuntimeExtensionRegistryTest, LatestExtensionVersion) { + EXPECT_THAT(Run("hello_extension_alias", RuntimeExtensionRegistry::kLatest, + "'world'.hello()"), + IsOkAndHolds(StringValueIs("Hello, new world!"))); +} + +} // namespace +} // namespace cel::env_internal diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc new file mode 100644 index 000000000..167e3b104 --- /dev/null +++ b/env/runtime_std_extensions.cc @@ -0,0 +1,130 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "checker/optional.h" +#include "env/env_runtime.h" +#include "env/internal/runtime_ext_registry.h" +#include "extensions/encoders.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext.h" +#include "extensions/math_ext_decls.h" +#include "extensions/regex_ext.h" +#include "extensions/sets_functions.h" +#include "extensions/strings.h" +#include "runtime/optional_types.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" + +namespace cel { + +void RegisterStandardExtensions(EnvRuntime& env_runtime) { + env_internal::RuntimeExtensionRegistry& registry = + env_runtime.GetRuntimeExtensionRegistry(); + registry.AddFunctionRegistration( + "cel.lib.ext.bindings", "bindings", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. 
+ return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.encoders", "encoders", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterEncodersFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kListsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.lists", "lists", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterListsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= extensions::kMathExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.math", "math", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + for (int version = 0; version <= cel::kOptionalExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.optional", "optional", version, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::EnableOptionalTypes(runtime_builder); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.protos", "protos", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. 
+ return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.sets", "sets", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterSetsFunctions( + runtime_builder.function_registry(), runtime_options); + }); + + for (int version = 0; version <= extensions::kStringsExtensionLatestVersion; + ++version) { + registry.AddFunctionRegistration( + "cel.lib.ext.strings", "strings", version, + [version](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterStringsFunctions( + runtime_builder.function_registry(), runtime_options, version); + }); + } + + registry.AddFunctionRegistration( + "cel.lib.ext.comprev2", "two-var-comprehensions", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + // No runtime functions to register. + return absl::OkStatus(); + }); + + registry.AddFunctionRegistration( + "cel.lib.ext.regex", "regex", 0, + [](RuntimeBuilder& runtime_builder, + const RuntimeOptions& runtime_options) -> absl::Status { + return cel::extensions::RegisterRegexExtensionFunctions( + runtime_builder); + }); +} + +} // namespace cel diff --git a/env/runtime_std_extensions.h b/env/runtime_std_extensions.h new file mode 100644 index 000000000..d7f714226 --- /dev/null +++ b/env/runtime_std_extensions.h @@ -0,0 +1,46 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// Registers the standard CEL extension functions with the given environment
// runtime. This makes them available, but does not enable them. See Env::Config
// for how to enable extensions.
//
// Extensions registered by this call (canonical name, alias):
//
// - cel.lib.ext.bindings (alias: "bindings")
// - cel.lib.ext.encoders (alias: "encoders")
// - cel.lib.ext.lists (alias: "lists")
// - cel.lib.ext.math (alias: "math")
// - cel.lib.ext.optional (alias: "optional")
// - cel.lib.ext.protos (alias: "protos")
// - cel.lib.ext.sets (alias: "sets")
// - cel.lib.ext.strings (alias: "strings")
// - cel.lib.ext.comprev2 (alias: "two-var-comprehensions")
// - cel.lib.ext.regex (alias: "regex")
//
// The lists, math, optional and strings extensions are registered once per
// version, from 0 up to their respective latest version; all others are
// registered at version 0 only.
//
// NOTE(review): the implementation in runtime_std_extensions.cc registers
// cel.lib.ext.regex together with the other extensions. Confirm whether the
// compile-side standard environment registers it as well before relying on
// regex being usable end to end.
//
void RegisterStandardExtensions(EnvRuntime& env_runtime);
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/runtime_std_extensions.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "checker/optional.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "extensions/lists_functions.h" +#include "extensions/math_ext_decls.h" +#include "extensions/strings.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::IsEmpty; +using ::testing::ValuesIn; + +struct TestCase { + std::string extension_name; + std::vector extension_versions = {0}; + int latest_extension_version = 0; + std::string expr; + bool requires_optional_extension = false; +}; + +using RuntimeStdExtensionTest = testing::TestWithParam; + +TEST_P(RuntimeStdExtensionTest, RegisterStandardExtensions) { + const TestCase& param = GetParam(); + Env env; + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env); + + Config compiler_config; + // For the compilation step, assume latest version of the extension to ensure + // a successful compilation. Later, we will test the runtime with different + // extension versions. 
+ ASSERT_THAT(compiler_config.AddExtensionConfig( + param.extension_name, Config::ExtensionConfig::kLatest), + IsOk()); + env.SetConfig(compiler_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile(param.expr)); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + for (int version = 0; version <= param.latest_extension_version; ++version) { + Config runtime_config; + // Request a specific version of the extension to be configured in the + // runtime. + ASSERT_THAT( + runtime_config.AddExtensionConfig(param.extension_name, version), + IsOk()); + if (param.requires_optional_extension) { + ASSERT_THAT(runtime_config.AddExtensionConfig("optional"), IsOk()); + } + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool( + cel::internal::GetSharedTestingDescriptorPool()); + RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(runtime_config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + absl::StatusOr> program_or = + runtime->CreateProgram(std::make_unique(*ast)); + + // If the function is not supported in this extension version, check that + // the program creation returned an error. 
+ if (!absl::c_contains(param.extension_versions, version)) { + EXPECT_THAT(program_or, StatusIs(absl::StatusCode::kInvalidArgument)) + << " expr: " << param.expr << " version: " << version; + continue; + } + + ASSERT_THAT(program_or, IsOk()) + << " expr: " << param.expr << " version: " << version; + std::unique_ptr program = *std::move(program_or); + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()) + << " expr: " << param.expr << " version: " << version; + } +} + +std::vector GetRuntimeStdExtensionTestCases() { + return { + TestCase{ + // The "bindings" extension does not register any runtime functions - + // only macros. + .extension_name = "bindings", + .expr = "cel.bind(t, 42, t + 1) == 43", + }, + TestCase{ + .extension_name = "encoders", + .expr = "base64.encode(b'hello') == 'aGVsbG8='", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].slice(0, 1) == [3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[[1, 2], 3].flatten() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "lists", + .extension_versions = {2}, + .latest_extension_version = extensions::kListsExtensionLatestVersion, + .expr = "[3, 2, 1].sort() == [1, 2, 3]", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {0, 1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.least([1, -2, 3]) == -2", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {1, 2}, + .latest_extension_version = extensions::kMathExtensionLatestVersion, + .expr = "math.floor(42.9) == 42.0", + }, + TestCase{ + .extension_name = "math", + .extension_versions = {2}, + .latest_extension_version = 
extensions::kMathExtensionLatestVersion, + .expr = "math.sqrt(4) == 2.0", + }, + TestCase{ + .extension_name = "optional", + .extension_versions = {0, 1, 2}, + .latest_extension_version = kOptionalExtensionLatestVersion, + .expr = "optional.of(1).hasValue()", + }, + TestCase{ + // No runtime functions. + .extension_name = "protos", + .expr = "!proto.hasExt(cel.expr.conformance.proto2.TestAllTypes{}, " + "cel.expr.conformance.proto2.nested_ext)", + }, + TestCase{ + .extension_name = "sets", + .expr = "sets.contains([1], [1])", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {0, 1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'Hello, who!'.replace('who', 'World') == 'Hello, World!'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {1, 2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "strings.quote('hello') == '\"hello\"'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {2, 3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "['hello', 'world'].join(', ') == 'hello, world'", + }, + TestCase{ + .extension_name = "strings", + .extension_versions = {3, 4}, + .latest_extension_version = + extensions::kStringsExtensionLatestVersion, + .expr = "'stressed'.reverse() == 'desserts'", + }, + TestCase{ + // No runtime functions. 
+ .extension_name = "cel.lib.ext.comprev2", + .expr = "[1, 2, 3].map(i, i * 2) == [2, 4, 6]", + }, + TestCase{ + .extension_name = "cel.lib.ext.regex", + .expr = "regex.replace('abc', '$', '_end') == 'abc_end'", + .requires_optional_extension = true, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(RuntimeStdExtensionTest, RuntimeStdExtensionTest, + ValuesIn(GetRuntimeStdExtensionTestCases())); + +} // namespace +} // namespace cel From f4d72d97c79f45c323cbe1036cfc067953bc070a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 24 Mar 2026 19:36:40 -0700 Subject: [PATCH 11/88] Add planner support for checking runtime extensions in the AST. Introduces check_ast_extensions to extract and validate runtime-affecting extensions from the SourceInfo. Currently, flat_expr_builder returns an error if any runtime extensions are present, as support is not yet implemented. Fixes ExtensionSpec copy constructor and assignment to correctly handle null version_. PiperOrigin-RevId: 888974129 --- common/ast/metadata.cc | 10 +- common/ast/metadata_test.cc | 31 ++++++ eval/compiler/BUILD | 29 +++++- eval/compiler/check_ast_extensions.cc | 58 +++++++++++ eval/compiler/check_ast_extensions.h | 34 +++++++ eval/compiler/check_ast_extensions_test.cc | 110 +++++++++++++++++++++ eval/compiler/flat_expr_builder.cc | 32 ++++++ eval/compiler/flat_expr_builder_test.cc | 42 +++++--- 8 files changed, 328 insertions(+), 18 deletions(-) create mode 100644 eval/compiler/check_ast_extensions.cc create mode 100644 eval/compiler/check_ast_extensions.h create mode 100644 eval/compiler/check_ast_extensions_test.cc diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc index f744deb00..eecb0dbb3 100644 --- a/common/ast/metadata.cc +++ b/common/ast/metadata.cc @@ -61,12 +61,18 @@ const ExtensionSpec& ExtensionSpec::DefaultInstance() { ExtensionSpec::ExtensionSpec(const ExtensionSpec& other) : id_(other.id_), affected_components_(other.affected_components_), - version_(std::make_unique(*other.version_)) 
{} + version_(other.version_ == nullptr + ? nullptr + : std::make_unique(*other.version_)) {} ExtensionSpec& ExtensionSpec::operator=(const ExtensionSpec& other) { id_ = other.id_; affected_components_ = other.affected_components_; - version_ = std::make_unique(*other.version_); + if (other.version_ != nullptr) { + version_ = std::make_unique(other.version()); + } else { + version_ = nullptr; + } return *this; } diff --git a/common/ast/metadata_test.cc b/common/ast/metadata_test.cc index 4afb0d07d..5553f4c8f 100644 --- a/common/ast/metadata_test.cc +++ b/common/ast/metadata_test.cc @@ -25,6 +25,8 @@ namespace cel { namespace { +using ::testing::ElementsAre; + TEST(AstTest, ListTypeSpecMutableConstruction) { ListTypeSpec type; type.mutable_elem_type() = TypeSpec(PrimitiveType::kBool); @@ -264,5 +266,34 @@ TEST(AstTest, ExtensionSpecEquality) { std::make_unique(0, 0), {})); } +TEST(AstTest, ExtensionCopyMove) { + ExtensionSpec a("constant_folding", nullptr, {}); + a.mutable_version().set_major(1); + a.mutable_version().set_minor(2); + a.mutable_affected_components().push_back(ExtensionSpec::Component::kRuntime); + + ExtensionSpec b(a); + + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 1); + EXPECT_EQ(b.version().minor(), 2); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + ExtensionSpec c(std::move(b)); + EXPECT_EQ(c, a); + + a.set_version(nullptr); + b = a; + EXPECT_EQ(b.id(), "constant_folding"); + EXPECT_EQ(b.version().major(), 0); + EXPECT_EQ(b.version().minor(), 0); + EXPECT_THAT(b.affected_components(), + ElementsAre(ExtensionSpec::Component::kRuntime)); + + c = std::move(b); + EXPECT_EQ(c, a); +} + } // namespace } // namespace cel diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 62b208772..ed8e4d20c 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -33,7 +33,6 @@ cc_library( "//base:data", "//common:expr", "//common:native_type", - "//common:navigable_ast", 
"//common:value", "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", @@ -96,6 +95,7 @@ cc_library( "flat_expr_builder.h", ], deps = [ + ":check_ast_extensions", ":flat_expr_builder_extensions", ":resolver", "//base:ast", @@ -413,6 +413,33 @@ cc_library( ], ) +cc_library( + name = "check_ast_extensions", + srcs = ["check_ast_extensions.cc"], + hdrs = ["check_ast_extensions.h"], + deps = [ + "//common:ast", + "//common/ast:metadata", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "check_ast_extensions_test", + srcs = ["check_ast_extensions_test.cc"], + deps = [ + ":check_ast_extensions", + "//common:ast", + "//common:expr", + "//common/ast:metadata", + "//internal:testing", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "resolver", srcs = ["resolver.cc"], diff --git a/eval/compiler/check_ast_extensions.cc b/eval/compiler/check_ast_extensions.cc new file mode 100644 index 000000000..37181b535 --- /dev/null +++ b/eval/compiler/check_ast_extensions.cc @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "eval/compiler/check_ast_extensions.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast) { + std::vector runtime_extensions; + absl::flat_hash_set seen_extension_ids; + + for (const cel::ExtensionSpec& extension : ast.source_info().extensions()) { + bool is_runtime = false; + for (const cel::ExtensionSpec::Component& component : + extension.affected_components()) { + if (component == cel::ExtensionSpec::Component::kRuntime) { + is_runtime = true; + break; + } + } + + if (!is_runtime) { + continue; + } + + if (!seen_extension_ids.insert(extension.id()).second) { + return absl::InvalidArgumentError( + absl::StrCat("duplicate extension ID: ", extension.id())); + } + runtime_extensions.push_back(extension); + } + + return runtime_extensions; +} + +} // namespace google::api::expr::runtime diff --git a/eval/compiler/check_ast_extensions.h b/eval/compiler/check_ast_extensions.h new file mode 100644 index 000000000..443c6ac09 --- /dev/null +++ b/eval/compiler/check_ast_extensions.h @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ + +#include + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/ast/metadata.h" + +namespace google::api::expr::runtime { + +// Extracts and validates extension tags from the AST `ast` that affect the +// runtime component. Returns the validated list of runtime extensions, or an +// error if there are multiple runtime extensions with the same ID. +absl::StatusOr> +ExtractAndValidateRuntimeExtensions(const cel::Ast& ast); + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_CHECK_AST_EXTENSIONS_H_ diff --git a/eval/compiler/check_ast_extensions_test.cc b/eval/compiler/check_ast_extensions_test.cc new file mode 100644 index 000000000..9e5838905 --- /dev/null +++ b/eval/compiler/check_ast_extensions_test.cc @@ -0,0 +1,110 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "eval/compiler/check_ast_extensions.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "common/ast.h" +#include "common/ast/metadata.h" +#include "common/expr.h" +#include "internal/testing.h" + +namespace google::api::expr::runtime { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::Ast; +using ::cel::Expr; +using ::cel::ExtensionSpec; +using ::cel::SourceInfo; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; +using ::testing::SizeIs; + +TEST(ExtractAndValidateRuntimeExtensionsTest, EmptyExtensions) { + Ast ast(Expr{}, SourceInfo{}); + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FiltersNonRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext2", nullptr, {ExtensionSpec::Component::kTypeChecker})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + IsOkAndHolds(SizeIs(0))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, ExtractsRuntimeExtensions) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext2", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext3", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")), + 
Property(&ExtensionSpec::id, Eq("ext2")))); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, FailsOnDuplicateRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back(ExtensionSpec( + "ext1", nullptr, + {ExtensionSpec::Component::kParser, ExtensionSpec::Component::kRuntime})); + + Ast ast(Expr(), std::move(source_info)); + + EXPECT_THAT(ExtractAndValidateRuntimeExtensions(ast), + StatusIs(absl::StatusCode::kInvalidArgument, + "duplicate extension ID: ext1")); +} + +TEST(ExtractAndValidateRuntimeExtensionsTest, IgnoresDuplicateNonRuntimeID) { + SourceInfo source_info; + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kRuntime})); + source_info.mutable_extensions().push_back( + ExtensionSpec("ext1", nullptr, {ExtensionSpec::Component::kParser})); + + Ast ast(Expr(), std::move(source_info)); + + auto result = ExtractAndValidateRuntimeExtensions(ast); + ASSERT_THAT(result, IsOk()); + EXPECT_THAT(*result, ElementsAre(Property(&ExtensionSpec::id, Eq("ext1")))); +} + +} // namespace +} // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 91822092c..e38c912c0 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -60,6 +60,7 @@ #include "common/kind.h" #include "common/type.h" #include "common/value.h" +#include "eval/compiler/check_ast_extensions.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" @@ -2512,6 +2513,22 @@ std::vector FlattenExpressionTable( return subexpression_indexes; } +absl::Status CheckAstExtensions( + const std::vector& extensions) { + for (const cel::ExtensionSpec& extension : extensions) { + if (extension.id() == "cel_block" && 
extension.version().major() == 1) { + // cel_block v1 is always supported. + continue; + } + + // TODO(uncreated-issue/89): Add support for json field names. + return absl::InvalidArgumentError(absl::StrCat( + "unsupported CEL extension: ", extension.id(), "@", + extension.version().major(), ".", extension.version().minor())); + } + return absl::OkStatus(); +} + } // namespace absl::StatusOr FlatExprBuilder::CreateExpressionImpl( @@ -2525,6 +2542,21 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( ? RuntimeIssue::Severity::kWarning : RuntimeIssue::Severity::kError; IssueCollector issue_collector(max_severity); + + absl::StatusOr> runtime_extensions = + ExtractAndValidateRuntimeExtensions(*ast); + + if (!runtime_extensions.ok()) { + CEL_RETURN_IF_ERROR(issue_collector.AddIssue( + RuntimeIssue::CreateError(runtime_extensions.status()))); + } + + auto status = CheckAstExtensions(*runtime_extensions); + if (!status.ok()) { + CEL_RETURN_IF_ERROR( + issue_collector.AddIssue(RuntimeIssue::CreateError(status))); + } + Resolver resolver(container_, function_registry_, type_registry_, GetTypeProvider(), options_.enable_qualified_type_identifiers); diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2b705398a..5fc20f01e 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1,18 +1,16 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "eval/compiler/flat_expr_builder.h" @@ -187,6 +185,20 @@ TEST(FlatExprBuilderTest, ExprUnset) { HasSubstr("Invalid empty expression"))); } +TEST(FlatExprBuilderTest, RuntimeExtensionsError) { + Expr expr; + SourceInfo source_info; + auto* ext = source_info.add_extensions(); + ext->set_id("ext1"); + ext->add_affected_components( + cel::expr::SourceInfo_Extension_Component_COMPONENT_RUNTIME); + + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); + EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("unsupported CEL extension: ext1"))); +} + TEST(FlatExprBuilderTest, ConstValueUnset) { Expr expr; SourceInfo source_info; From f6cd0c895e42954475b3734c867e8902557d1b37 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Wed, 25 Mar 2026 10:49:56 -0700 Subject: [PATCH 12/88] Fix YAML syntax error reporting on non-map YAML (e.g. 
"hello") PiperOrigin-RevId: 889328156 --- env/env_yaml.cc | 9 +++++++++ env/env_yaml_test.cc | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 0035709e9..5e7c9631d 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -998,6 +998,15 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { Config config; CEL_ASSIGN_OR_RETURN(YAML::Node root, LoadYaml(yaml)); + if (!root.IsDefined() || root.IsNull()) { + return config; + } + + if (!root.IsMap()) { + return absl::InvalidArgumentError(FormatYamlErrorMessage( + yaml, "Invalid CEL environment config YAML", root.Mark())); + } + CEL_RETURN_IF_ERROR(ParseName(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 828a39b48..b34c25254 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -528,6 +528,12 @@ TEST_P(EnvYamlParseTest, EnvYamlSyntaxError) { INSTANTIATE_TEST_SUITE_P( EnvYamlParseTest, EnvYamlParseTest, ::testing::Values( + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Invalid CEL environment config YAML\n" + "| invalid yaml \n" + "| ^", + }, ParseTestCase{ .yaml = R"yaml( name: From 89d4f5fae45941cfb6a258256ead905a1d443414 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 27 Mar 2026 14:45:46 -0700 Subject: [PATCH 13/88] Update the CelExpressionBuilder documentation, relaxing the input parameter lifetime requirement PiperOrigin-RevId: 890645482 --- eval/public/cel_expression.h | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 3f52ad60d..4cf029e89 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -80,10 +80,10 @@ class CelExpressionBuilder { virtual 
~CelExpressionBuilder() = default; // Creates CelExpression object from AST tree. - // expr specifies root of AST tree - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // expr specifies root of AST tree. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info) const = 0; @@ -91,9 +91,9 @@ class CelExpressionBuilder { // Creates CelExpression object from AST tree. // expr specifies root of AST tree. // non-fatal build warnings are written to warnings if encountered. - // - // IMPORTANT: The `expr` and `source_info` must outlive the resulting - // CelExpression. + // Method implementation is expected to create copies of expr and source_info, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::Expr* expr, const cel::expr::SourceInfo* source_info, @@ -101,8 +101,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. - // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr) const { // Default implementation just passes through the expr and source info. @@ -113,8 +114,9 @@ class CelExpressionBuilder { // Creates CelExpression object from a checked expression. // This includes an AST, source info, type hints and ident hints. // non-fatal build warnings are written to warnings if encountered. 
- // - // IMPORTANT: The `checked_expr` must outlive the resulting CelExpression. + // Method implementation is expected to create copy of checked_expr, + // so that the returned CelExpression is not dependent on the lifetime of + // the input arguments. virtual absl::StatusOr> CreateExpression( const cel::expr::CheckedExpr* checked_expr, std::vector* warnings) const { From 4015e768b83cc2928a73e78a2bcd10c3894d0668 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 27 Mar 2026 14:56:12 -0700 Subject: [PATCH 14/88] Add option for format precision limit. High values for precision can lead to very large strings (exponential time / memory cost w.r.t format specifier), but are technically allowed. Will lower the default value in a follow-up. PiperOrigin-RevId: 890649934 --- env/runtime_std_extensions.cc | 5 ++- extensions/BUILD | 1 + extensions/formatting.cc | 60 ++++++++++++++++++++--------------- extensions/formatting.h | 13 ++++++-- extensions/formatting_test.cc | 26 +++++++++++++++ extensions/strings.cc | 19 +++++++---- extensions/strings.h | 34 +++++++++++++++++--- extensions/strings_test.cc | 44 +++++++++++++++++++++++++ 8 files changed, 161 insertions(+), 41 deletions(-) diff --git a/env/runtime_std_extensions.cc b/env/runtime_std_extensions.cc index 167e3b104..b866a5965 100644 --- a/env/runtime_std_extensions.cc +++ b/env/runtime_std_extensions.cc @@ -105,8 +105,11 @@ void RegisterStandardExtensions(EnvRuntime& env_runtime) { "cel.lib.ext.strings", "strings", version, [version](RuntimeBuilder& runtime_builder, const RuntimeOptions& runtime_options) -> absl::Status { + cel::extensions::StringsExtensionOptions strings_options; + strings_options.version = version; return cel::extensions::RegisterStringsFunctions( - runtime_builder.function_registry(), runtime_options, version); + runtime_builder.function_registry(), runtime_options, + strings_options); }); } diff --git a/extensions/BUILD b/extensions/BUILD index fe97af46a..c393ec13a 100644 --- 
a/extensions/BUILD +++ b/extensions/BUILD @@ -609,6 +609,7 @@ cc_test( "//checker:type_check_issue", "//checker:type_checker_builder", "//checker:validation_result", + "//common:ast", "//common:decl", "//common:type", "//common:value", diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 6e58a7b86..935815569 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -14,20 +14,19 @@ #include "extensions/formatting.h" +#include #include #include #include #include #include #include -#include #include #include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/btree_map.h" -#include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -64,7 +63,7 @@ absl::StatusOr FormatString( std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND); absl::StatusOr>> ParsePrecision( - absl::string_view format) { + absl::string_view format, int max_precision) { if (format.empty() || format[0] != '.') return std::pair{0, std::nullopt}; int64_t i = 1; @@ -80,9 +79,9 @@ absl::StatusOr>> ParsePrecision( return absl::InvalidArgumentError( "unable to convert precision specifier to integer"); } - if (precision > kMaxPrecision) { + if (precision > max_precision) { return absl::InvalidArgumentError( - absl::StrCat("precision specifier exceeds maximum of ", kMaxPrecision)); + absl::StrCat("precision specifier exceeds maximum of ", max_precision)); } return std::pair{i, precision}; } @@ -444,12 +443,13 @@ absl::StatusOr FormatScientific( } absl::StatusOr> ParseAndFormatClause( - absl::string_view format, const Value& value, + absl::string_view format, const Value& value, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { - CEL_ASSIGN_OR_RETURN(auto 
precision_pair, ParsePrecision(format)); + CEL_ASSIGN_OR_RETURN(auto precision_pair, + ParsePrecision(format, max_precision)); auto [read, precision] = precision_pair; switch (format[read]) { case 's': { @@ -494,7 +494,7 @@ absl::StatusOr> ParseAndFormatClause( } absl::StatusOr Format( - const StringValue& format_value, const ListValue& args, + const StringValue& format_value, const ListValue& args, int max_precision, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { @@ -512,43 +512,51 @@ absl::StatusOr Format( } ++i; if (i >= format.size()) { - return absl::InvalidArgumentError("unexpected end of format string"); + return ErrorValue( + absl::InvalidArgumentError("unexpected end of format string")); } if (format[i] == '%') { result.push_back('%'); continue; } if (arg_index >= args_size) { - return absl::InvalidArgumentError( - absl::StrFormat("index %d out of range", arg_index)); + return ErrorValue(absl::InvalidArgumentError( + absl::StrFormat("index %d out of range", arg_index))); } CEL_ASSIGN_OR_RETURN(auto value, args.Get(arg_index++, descriptor_pool, message_factory, arena)); - CEL_ASSIGN_OR_RETURN( - auto clause, - ParseAndFormatClause(format.substr(i), value, descriptor_pool, - message_factory, arena, clause_scratch)); - absl::StrAppend(&result, clause.second); - i += clause.first; + + auto clause = ParseAndFormatClause(format.substr(i), value, max_precision, + descriptor_pool, message_factory, arena, + clause_scratch); + if (!clause.ok()) { + return ErrorValue(std::move(clause).status()); + } + absl::StrAppend(&result, clause->second); + i += clause->first; } - return StringValue(arena, std::move(result)); + return StringValue::From(std::move(result), arena); } } // namespace -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { +absl::Status 
RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options) { + const int max_precision = + std::clamp(format_options.max_precision, 0, kMaxPrecision); CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, ListValue>:: CreateDescriptor("format", /*receiver_style=*/true), BinaryFunctionAdapter, StringValue, ListValue>:: WrapFunction( - [](const StringValue& format, const ListValue& args, - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, - google::protobuf::MessageFactory* absl_nonnull message_factory, - google::protobuf::Arena* absl_nonnull arena) { - return Format(format, args, descriptor_pool, message_factory, - arena); + [max_precision]( + const StringValue& format, const ListValue& args, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) { + return Format(format, args, max_precision, descriptor_pool, + message_factory, arena); }))); return absl::OkStatus(); } diff --git a/extensions/formatting.h b/extensions/formatting.h index bc2002006..88954857b 100644 --- a/extensions/formatting.h +++ b/extensions/formatting.h @@ -21,9 +21,18 @@ namespace cel::extensions { +struct StringsExtensionFormatOptions { + // The maximum precision to permit for formatting floating-point numbers. + int max_precision = 1000; +}; + // Register extension functions for string formatting. -absl::Status RegisterStringFormattingFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +// +// This implements (string).format([args...]) in the strings extension. Most +// users should add these functions via `extensions/strings.h` instead. 
+absl::Status RegisterStringFormattingFunctions( + FunctionRegistry& registry, const RuntimeOptions& options, + StringsExtensionFormatOptions format_options = {}); } // namespace cel::extensions diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index 824f14e45..b80fe9bc0 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -96,6 +96,32 @@ TEST_P(StringFormatLimitsTest, FormatLimits) { } } +TEST(StringFormatLimitsTest, MaxPrecisionOption) { + google::protobuf::Arena arena; + const RuntimeOptions options; + StringsExtensionFormatOptions format_options; + format_options.max_precision = 99; + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT(RegisterStringFormattingFunctions(builder.function_registry(), + options, format_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("'%.100f'.format([1.123])", + "", ParserOptions{})); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + ProtobufRuntimeAdapter::CreateProgram(*runtime, expr)); + Activation activation; + + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus().message(), + HasSubstr("precision specifier exceeds maximum of 99")); +} + INSTANTIATE_TEST_SUITE_P(StringFormatLimitsTest, StringFormatLimitsTest, ValuesIn({ "double('%.326f'.format([x])) == x", diff --git a/extensions/strings.cc b/extensions/strings.cc index ed6f27319..54fda20d6 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -305,9 +305,10 @@ absl::Status RegisterStringsDecls(TypeCheckerBuilder& builder, int version) { } // namespace -absl::Status RegisterStringsFunctions(FunctionRegistry& registry, - const RuntimeOptions& options, - int version) { +absl::Status RegisterStringsFunctions( + FunctionRegistry& registry, const RuntimeOptions& 
options, + const StringsExtensionOptions& extension_options) { + const int version = extension_options.version; CEL_RETURN_IF_ERROR(registry.Register( BinaryFunctionAdapter, StringValue, StringValue>:: CreateDescriptor("split", /*receiver_style=*/true), @@ -382,7 +383,8 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, return absl::OkStatus(); } - CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions(registry, options)); + CEL_RETURN_IF_ERROR(RegisterStringFormattingFunctions( + registry, options, {extension_options.max_precision})); CEL_RETURN_IF_ERROR( (UnaryFunctionAdapter::RegisterGlobalOverload( "strings.quote", &Quote, registry))); @@ -412,13 +414,16 @@ absl::Status RegisterStringsFunctions(FunctionRegistry& registry, absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options) { + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options) { return RegisterStringsFunctions( registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + google::api::expr::runtime::ConvertToRuntimeOptions(options), + extension_options); } -CheckerLibrary StringsCheckerLibrary(int version) { +CheckerLibrary StringsCheckerLibrary(const StringsExtensionOptions& options) { + const int version = options.version; return {"strings", [version](TypeCheckerBuilder& builder) { return RegisterStringsDecls(builder, version); }}; diff --git a/extensions/strings.h b/extensions/strings.h index 3cbc9f19f..3ec92d603 100644 --- a/extensions/strings.h +++ b/extensions/strings.h @@ -27,21 +27,45 @@ namespace cel::extensions { constexpr int kStringsExtensionLatestVersion = 4; +struct StringsExtensionOptions { + int version = kStringsExtensionLatestVersion; + + // Maximum precision allowed for floating point format specifiers in + // format() function. 
This is used for both fixed and scientific notations. + // Value must be in the range [0, 1000], otherwise clamped. + // + // Does not affect default precisions for %e and %f format specifiers. + int max_precision = 1000; +}; + // Register extension functions for strings. absl::Status RegisterStringsFunctions( FunctionRegistry& registry, const RuntimeOptions& options, - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); absl::Status RegisterStringsFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, - const google::api::expr::runtime::InterpreterOptions& options); + const google::api::expr::runtime::InterpreterOptions& options, + const StringsExtensionOptions& extension_options = {}); CheckerLibrary StringsCheckerLibrary( - int version = kStringsExtensionLatestVersion); + const StringsExtensionOptions& extension_options = {}); + +inline CheckerLibrary StringsCheckerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCheckerLibrary(options); +} inline CompilerLibrary StringsCompilerLibrary( - int version = kStringsExtensionLatestVersion) { - return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(version)); + const StringsExtensionOptions& options = {}) { + return CompilerLibrary::FromCheckerLibrary(StringsCheckerLibrary(options)); +} + +inline CompilerLibrary StringsCompilerLibrary(int version) { + StringsExtensionOptions options; + options.version = version; + return StringsCompilerLibrary(options); } } // namespace cel::extensions diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index a5d56eaed..c3059808f 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -27,6 +27,7 @@ #include "checker/type_check_issue.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "common/value.h" @@ 
-50,6 +51,7 @@ namespace cel::extensions { namespace { using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; using ::cel::expr::ParsedExpr; using ::google::api::expr::parser::Parse; using ::google::api::expr::parser::ParserOptions; @@ -85,6 +87,48 @@ TEST(StringsCheckerLibrary, SmokeTest) { )~bool^equals)"); } +TEST(StringsExtTest, MaxPrecisionOption) { + StringsExtensionOptions extension_options; + extension_options.max_precision = 99; + + ASSERT_OK_AND_ASSIGN( + auto compiler_builder, + NewCompilerBuilder(internal::GetTestingDescriptorPool())); + + ASSERT_THAT(compiler_builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(StringsCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + compiler_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + compiler->Compile("'abc %.100f'.format([2.0])", "")); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + RuntimeOptions opts; + ASSERT_OK_AND_ASSIGN( + auto runtime_builder, + CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(), opts)); + + ASSERT_THAT(RegisterStringsFunctions(runtime_builder.function_registry(), + opts, extension_options), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, program->Evaluate(&arena, activation)); + + ASSERT_TRUE(value.Is()); + EXPECT_THAT(value.GetError().ToStatus(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("precision specifier exceeds maximum of 99"))); +} + using StringsExtFunctionsTest = testing::TestWithParam; TEST_P(StringsExtFunctionsTest, ParserAndCheckerTests) { From c93de861d872fb345512117b79aa2c54f4f93759 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 30 Mar 2026 10:32:41 -0700 Subject: [PATCH 15/88] Enable 
escaped backtick quoted identifiers by default. PiperOrigin-RevId: 891789266 --- parser/options.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/parser/options.h b/parser/options.h index ad03102e8..a41d16104 100644 --- a/parser/options.h +++ b/parser/options.h @@ -57,8 +57,9 @@ struct ParserOptions final { // Enables support for identifier quoting syntax: // "message.`skewer-case-field`" // - // Limited to field specifiers in select and message creation. - bool enable_quoted_identifiers = false; + // Limited to field specifiers in select and message creation, + // enabled by default + bool enable_quoted_identifiers = true; }; } // namespace cel From 5a3463337cf2a9b90b53833af2bbc1f35da90d64 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 30 Mar 2026 11:20:09 -0700 Subject: [PATCH 16/88] Expose TypeInfoToType and NewCompilerBuilder APIs PiperOrigin-RevId: 891814525 --- env/BUILD | 39 ++++++--- env/env.cc | 168 ++++----------------------------------- env/env.h | 5 ++ env/env_test.cc | 112 -------------------------- env/env_yaml.cc | 5 +- env/env_yaml_test.cc | 6 +- env/type_info.cc | 178 ++++++++++++++++++++++++++++++++++++++++++ env/type_info.h | 35 +++++++++ env/type_info_test.cc | 127 ++++++++++++++++++++++++++++++ 9 files changed, 397 insertions(+), 278 deletions(-) create mode 100644 env/type_info.cc create mode 100644 env/type_info.h create mode 100644 env/type_info_test.cc diff --git a/env/BUILD b/env/BUILD index f5ce35557..8d477cc1f 100644 --- a/env/BUILD +++ b/env/BUILD @@ -19,17 +19,28 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "config", - srcs = ["config.cc"], - hdrs = ["config.h"], + srcs = [ + "config.cc", + "type_info.cc", + ], + hdrs = [ + "config.h", + "type_info.h", + ], deps = [ "//common:constant", + "//common:type", + "//common:type_kind", "//internal:status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -43,23 +54,16 @@ cc_library( "//common:constant", "//common:decl", "//common:type", - "//common:type_kind", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", "//env/internal:ext_registry", "//internal:status_macros", "//parser:macro", - "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/functional:overload", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) @@ -163,6 +167,20 @@ cc_test( ], ) +cc_test( + name = "type_info_test", + srcs = ["type_info_test.cc"], + deps = [ + ":config", + "//common:type", + "//common:type_proto", + "//internal:proto_matchers", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "env_test", srcs = ["env_test.cc"], @@ -173,7 +191,6 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", - "//common:ast_proto", "//common:constant", "//common:decl", "//common:expr", diff --git a/env/env.cc b/env/env.cc index 2c2555f14..5a4198497 100644 --- a/env/env.cc +++ b/env/env.cc @@ -15,12 +15,10 @@ #include "env/env.h" #include -#include #include #include #include -#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -28,11 +26,11 @@ #include 
"common/constant.h" #include "common/decl.h" #include "common/type.h" -#include "common/type_kind.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" #include "env/config.h" +#include "env/type_info.h" #include "internal/status_macros.h" #include "parser/macro.h" #include "google/protobuf/arena.h" @@ -95,149 +93,6 @@ absl::StatusOr MakeStdlibSubset( return subset; } -std::optional TypeNameToTypeKind(absl::string_view type_name) { - // Excluded types: - // kUnknown - // kError - // kTypeParam - // kFunction - // kEnum - - static const absl::NoDestructor< - absl::flat_hash_map> - kTypeNameToTypeKind({ - {"null", TypeKind::kNull}, - {"bool", TypeKind::kBool}, - {"int", TypeKind::kInt}, - {"uint", TypeKind::kUint}, - {"double", TypeKind::kDouble}, - {"string", TypeKind::kString}, - {"bytes", TypeKind::kBytes}, - {"timestamp", TypeKind::kTimestamp}, - {TimestampType::kName, TypeKind::kTimestamp}, - {"duration", TypeKind::kDuration}, - {DurationType::kName, TypeKind::kDuration}, - {"list", TypeKind::kList}, - {"map", TypeKind::kMap}, - {"", TypeKind::kDyn}, - {"any", TypeKind::kAny}, - {"dyn", TypeKind::kDyn}, - {BoolWrapperType::kName, TypeKind::kBoolWrapper}, - {IntWrapperType::kName, TypeKind::kIntWrapper}, - {UintWrapperType::kName, TypeKind::kUintWrapper}, - {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, - {StringWrapperType::kName, TypeKind::kStringWrapper}, - {BytesWrapperType::kName, TypeKind::kBytesWrapper}, - {"type", TypeKind::kType}, - }); - if (auto it = kTypeNameToTypeKind->find(type_name); - it != kTypeNameToTypeKind->end()) { - return it->second; - } - - return std::nullopt; -} - -absl::StatusOr TypeInfoToType( - const Config::TypeInfo& type_info, google::protobuf::Arena* arena, - const google::protobuf::DescriptorPool* descriptor_pool) { - if (type_info.is_type_param) { - return TypeParamType(type_info.name); - } - - std::optional type_kind = TypeNameToTypeKind(type_info.name); - if 
(!type_kind.has_value()) { - if (type_info.params.empty() && descriptor_pool != nullptr) { - const google::protobuf::Descriptor* type = - descriptor_pool->FindMessageTypeByName(type_info.name); - if (type != nullptr) { - return MessageType(type); - } - } - // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types - std::vector parameter_types; - for (const Config::TypeInfo& param : type_info.params) { - CEL_ASSIGN_OR_RETURN(Type parameter_type, - TypeInfoToType(param, arena, descriptor_pool)); - parameter_types.push_back(parameter_type); - } - - return OpaqueType(arena, type_info.name, parameter_types); - } - - switch (*type_kind) { - case TypeKind::kNull: - return NullType(); - case TypeKind::kBool: - return BoolType(); - case TypeKind::kInt: - return IntType(); - case TypeKind::kUint: - return UintType(); - case TypeKind::kDouble: - return DoubleType(); - case TypeKind::kString: - return StringType(); - case TypeKind::kBytes: - return BytesType(); - case TypeKind::kDuration: - return DurationType(); - case TypeKind::kTimestamp: - return TimestampType(); - case TypeKind::kList: { - Type element_type; - if (!type_info.params.empty()) { - CEL_ASSIGN_OR_RETURN( - element_type, - TypeInfoToType(type_info.params[0], arena, descriptor_pool)); - } else { - element_type = DynType(); - } - return ListType(arena, element_type); - } - case TypeKind::kMap: { - Type key_type = DynType(); - Type value_type = DynType(); - if (!type_info.params.empty()) { - CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], - arena, descriptor_pool)); - } - if (type_info.params.size() > 1) { - CEL_ASSIGN_OR_RETURN( - value_type, - TypeInfoToType(type_info.params[1], arena, descriptor_pool)); - } - return MapType(arena, key_type, value_type); - } - case TypeKind::kDyn: - return DynType(); - case TypeKind::kAny: - return AnyType(); - case TypeKind::kBoolWrapper: - return BoolWrapperType(); - case TypeKind::kIntWrapper: - return IntWrapperType(); - case 
TypeKind::kUintWrapper: - return UintWrapperType(); - case TypeKind::kDoubleWrapper: - return DoubleWrapperType(); - case TypeKind::kStringWrapper: - return StringWrapperType(); - case TypeKind::kBytesWrapper: - return BytesWrapperType(); - case TypeKind::kType: { - if (type_info.params.empty()) { - return TypeType(arena, DynType()); - } - CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], arena, - descriptor_pool)); - return TypeType(arena, type); - } - default: - return DynType(); - } -} - absl::StatusOr FunctionConfigToFunctionDecl( const Config::FunctionConfig& function_config, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool* descriptor_pool) { @@ -250,12 +105,12 @@ absl::StatusOr FunctionConfigToFunctionDecl( overload_decl.set_member(overload_config.is_member_function); for (const Config::TypeInfo& parameter : overload_config.parameters) { CEL_ASSIGN_OR_RETURN(Type parameter_type, - TypeInfoToType(parameter, arena, descriptor_pool)); + TypeInfoToType(parameter, descriptor_pool, arena)); overload_decl.mutable_args().push_back(parameter_type); } CEL_ASSIGN_OR_RETURN( Type return_type, - TypeInfoToType(overload_config.return_type, arena, descriptor_pool)); + TypeInfoToType(overload_config.return_type, descriptor_pool, arena)); overload_decl.set_result(return_type); CEL_RETURN_IF_ERROR(function_decl.AddOverload(overload_decl)); } @@ -264,7 +119,11 @@ absl::StatusOr FunctionConfigToFunctionDecl( } // namespace -absl::StatusOr> Env::NewCompiler() { +Env::Env() { + compiler_options_.parser_options.enable_quoted_identifiers = true; +} + +absl::StatusOr> Env::NewCompilerBuilder() { CEL_ASSIGN_OR_RETURN( std::unique_ptr compiler_builder, cel::NewCompilerBuilder(descriptor_pool_, compiler_options_)); @@ -295,8 +154,8 @@ absl::StatusOr> Env::NewCompiler() { VariableDecl variable_decl; variable_decl.set_name(variable_config.name); CEL_ASSIGN_OR_RETURN(Type type, - TypeInfoToType(variable_config.type_info, arena, - 
descriptor_pool_.get())); + TypeInfoToType(variable_config.type_info, + descriptor_pool_.get(), arena)); variable_decl.set_type(type); if (variable_config.value.has_value()) { variable_decl.set_value(variable_config.value); @@ -312,7 +171,12 @@ absl::StatusOr> Env::NewCompiler() { CEL_RETURN_IF_ERROR(checker_builder.AddFunction(function_decl)); } - return compiler_builder->Build(); + return compiler_builder; } +absl::StatusOr> Env::NewCompiler() { + CEL_ASSIGN_OR_RETURN(std::unique_ptr compiler_builder, + NewCompilerBuilder()); + return compiler_builder->Build(); +} } // namespace cel diff --git a/env/env.h b/env/env.h index f46e5947c..9830b67d7 100644 --- a/env/env.h +++ b/env/env.h @@ -36,6 +36,8 @@ namespace cel { // customizable CEL features. class Env { public: + Env(); + // Registers a `CompilerLibrary` with the environment. Note that the library // does not automatically get added to a `Compiler`. `NewCompiler` relies // on `Config` to determine which libraries to load. @@ -57,6 +59,9 @@ class Env { void SetConfig(const Config& config) { config_ = config; } + absl::StatusOr> NewCompilerBuilder(); + + // Shortcut for NewCompilerBuilder() followed by Build(). 
absl::StatusOr> NewCompiler(); private: diff --git a/env/env_test.cc b/env/env_test.cc index dcd2d97fa..076eb57bc 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -30,7 +30,6 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" -#include "common/ast_proto.h" #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" @@ -38,7 +37,6 @@ #include "common/value.h" #include "compiler/compiler.h" #include "env/config.h" -#include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -52,16 +50,13 @@ #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" -#include "google/protobuf/text_format.h" namespace cel { namespace { using ::absl_testing::IsOk; -using ::cel::internal::test::EqualsProto; using ::testing::HasSubstr; using ::testing::IsEmpty; -using ::testing::NotNull; using ::testing::Property; using ::testing::UnorderedElementsAre; using ::testing::Values; @@ -319,113 +314,6 @@ TEST(ContainerConfigTest, ContainerConfig) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } -struct TypeInfoTestCase { - Config::TypeInfo type_info; - std::string expected_type_pb; -}; - -using TypeInfoTest = testing::TestWithParam; - -TEST_P(TypeInfoTest, TypeInfo) { - const TypeInfoTestCase& param = GetParam(); - cel::expr::Type expected_type_pb; - ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, - &expected_type_pb)); - - Env env; - env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); - Config config; - Config::VariableConfig variable_config; - variable_config.name = "test"; - variable_config.type_info = param.type_info; - ASSERT_THAT(config.AddVariableConfig(variable_config), IsOk()); - env.SetConfig(config); - - ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); - 
ASSERT_THAT(compiler, NotNull()); - ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("test")); - EXPECT_THAT(result.GetIssues(), IsEmpty()) - << " error: " << result.FormatError(); - - // Obtain the inferred return type of the expression `test`. - const Ast* ast = result.GetAst(); - ASSERT_THAT(ast, NotNull()); - cel::expr::CheckedExpr checked_expr; - ASSERT_THAT(cel::AstToCheckedExpr(*ast, &checked_expr), IsOk()); - auto it = checked_expr.type_map().find(checked_expr.expr().id()); - ASSERT_NE(it, checked_expr.type_map().end()); - - cel::expr::Type actual_type_pb = it->second; - EXPECT_THAT(actual_type_pb, EqualsProto(expected_type_pb)); -} - -std::vector GetTypeInfoTestCases() { - return { - TypeInfoTestCase{ - .type_info = {.name = "int"}, - .expected_type_pb = "primitive: INT64", - }, - TypeInfoTestCase{ - .type_info = {.name = "list", - .params = {Config::TypeInfo{.name = "int"}}}, - .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", - }, - TypeInfoTestCase{ - .type_info = {.name = "list"}, - .expected_type_pb = "list_type { elem_type { dyn {} }}", - }, - TypeInfoTestCase{ - .type_info = {.name = "map", - .params = {Config::TypeInfo{.name = "string"}, - Config::TypeInfo{.name = "int"}}}, - .expected_type_pb = "map_type { key_type { primitive: STRING } " - "value_type { primitive: INT64 }}", - }, - TypeInfoTestCase{ - .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, - .expected_type_pb = - "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", - }, - TypeInfoTestCase{ - .type_info = {.name = "A", - .params = {Config::TypeInfo{.name = "B", - .is_type_param = true}}}, - // TypeParam is replaced with dyn by the type checker. 
- .expected_type_pb = - "abstract_type { name: 'A' parameter_types { dyn {} } }", - }, - TypeInfoTestCase{ - .type_info = {.name = "any"}, - .expected_type_pb = "well_known: ANY", - }, - TypeInfoTestCase{ - .type_info = {.name = "timestamp"}, - .expected_type_pb = "well_known: TIMESTAMP", - }, - TypeInfoTestCase{ - .type_info = {.name = "google.protobuf.DoubleValue"}, - .expected_type_pb = "wrapper: DOUBLE", - }, - TypeInfoTestCase{ - .type_info = {.name = "type", - .params = {Config::TypeInfo{.name = "duration"}}}, - .expected_type_pb = "type: { well_known: DURATION }", - }, - TypeInfoTestCase{ - .type_info = {.name = "parameterized", - .params = {{.name = "A", .is_type_param = true}, - {.name = "double"}}}, - // TypeParam is replaced with dyn by the type checker. - .expected_type_pb = "abstract_type { name: 'parameterized' " - "parameter_types { dyn {} } " - "parameter_types { primitive: DOUBLE } }", - }, - }; -} - -INSTANTIATE_TEST_SUITE_P(VariableConfigTest, TypeInfoTest, - ValuesIn(GetTypeInfoTestCases())); - struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 5e7c9631d..a6f66bd83 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -552,10 +552,13 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, variable_config.type_info = type_info; - if (constant_kind_case != ConstantKindCase::kUnspecified) { + if (constant_kind_case != ConstantKindCase::kUnspecified && + !value_str.empty()) { CEL_ASSIGN_OR_RETURN( variable_config.value, ParseConstantValue(yaml, value, constant_kind_case, value_str)); + } else if (constant_kind_case == ConstantKindCase::kNull) { + variable_config.value = Constant(nullptr); } CEL_RETURN_IF_ERROR(config.AddVariableConfig(variable_config)); diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index b34c25254..c3e4839af 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -226,8 +226,10 @@ 
TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); - EXPECT_EQ(variable_config.type_info.name, param.type_name); - EXPECT_EQ(variable_config.value, param.expected_constant); + EXPECT_EQ(variable_config.type_info.name, param.type_name) + << " yaml: " << yaml; + EXPECT_EQ(variable_config.value, param.expected_constant) + << " yaml: " << yaml; } std::vector GetParseConstantTestCases() { diff --git a/env/type_info.cc b/env/type_info.cc new file mode 100644 index 000000000..ed72a842f --- /dev/null +++ b/env/type_info.cc @@ -0,0 +1,178 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "env/type_info.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "env/config.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +std::optional TypeNameToTypeKind(absl::string_view type_name) { + // Excluded types: + // kUnknown + // kError + // kTypeParam + // kFunction + // kEnum + + static const absl::NoDestructor< + absl::flat_hash_map> + kTypeNameToTypeKind({ + {"null", TypeKind::kNull}, + {"bool", TypeKind::kBool}, + {"int", TypeKind::kInt}, + {"uint", TypeKind::kUint}, + {"double", TypeKind::kDouble}, + {"string", TypeKind::kString}, + {"bytes", TypeKind::kBytes}, + {"timestamp", TypeKind::kTimestamp}, + {TimestampType::kName, TypeKind::kTimestamp}, + {"duration", TypeKind::kDuration}, + {DurationType::kName, TypeKind::kDuration}, + {"list", TypeKind::kList}, + {"map", TypeKind::kMap}, + {"", TypeKind::kDyn}, + {"any", TypeKind::kAny}, + {"dyn", TypeKind::kDyn}, + {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {IntWrapperType::kName, TypeKind::kIntWrapper}, + {UintWrapperType::kName, TypeKind::kUintWrapper}, + {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {StringWrapperType::kName, TypeKind::kStringWrapper}, + {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"type", TypeKind::kType}, + }); + if (auto it = kTypeNameToTypeKind->find(type_name); + it != kTypeNameToTypeKind->end()) { + return it->second; + } + + return std::nullopt; +} +} // namespace + +absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena) { + if (type_info.is_type_param) { + return TypeParamType(type_info.name); + } + + std::optional type_kind = 
TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty() && descriptor_pool != nullptr) { + const google::protobuf::Descriptor* type = + descriptor_pool->FindMessageTypeByName(type_info.name); + if (type != nullptr) { + return Type::Message(type); + } + } + // TODO(uncreated-issue/88): use a TypeIntrospector to validate opaque types + std::vector parameter_types; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(Type parameter_type, + TypeInfoToType(param, descriptor_pool, arena)); + parameter_types.push_back(parameter_type); + } + + return OpaqueType(arena, type_info.name, parameter_types); + } + + switch (*type_kind) { + case TypeKind::kNull: + return NullType(); + case TypeKind::kBool: + return BoolType(); + case TypeKind::kInt: + return IntType(); + case TypeKind::kUint: + return UintType(); + case TypeKind::kDouble: + return DoubleType(); + case TypeKind::kString: + return StringType(); + case TypeKind::kBytes: + return BytesType(); + case TypeKind::kDuration: + return DurationType(); + case TypeKind::kTimestamp: + return TimestampType(); + case TypeKind::kList: { + Type element_type; + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN( + element_type, + TypeInfoToType(type_info.params[0], descriptor_pool, arena)); + } else { + element_type = DynType(); + } + return ListType(arena, element_type); + } + case TypeKind::kMap: { + Type key_type = DynType(); + Type value_type = DynType(); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN( + value_type, + TypeInfoToType(type_info.params[1], descriptor_pool, arena)); + } + return MapType(arena, key_type, value_type); + } + case TypeKind::kDyn: + return DynType(); + case TypeKind::kAny: + return AnyType(); + case TypeKind::kBoolWrapper: + return BoolWrapperType(); + case TypeKind::kIntWrapper: + 
return IntWrapperType(); + case TypeKind::kUintWrapper: + return UintWrapperType(); + case TypeKind::kDoubleWrapper: + return DoubleWrapperType(); + case TypeKind::kStringWrapper: + return StringWrapperType(); + case TypeKind::kBytesWrapper: + return BytesWrapperType(); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeType(arena, DynType()); + } + CEL_ASSIGN_OR_RETURN(Type type, TypeInfoToType(type_info.params[0], + descriptor_pool, arena)); + return TypeType(arena, type); + } + default: + return DynType(); + } +} + +} // namespace cel diff --git a/env/type_info.h b/env/type_info.h new file mode 100644 index 000000000..bb3cfde43 --- /dev/null +++ b/env/type_info.h @@ -0,0 +1,35 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ +#define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ + +#include "absl/status/statusor.h" +#include "common/type.h" +#include "env/config.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Converts a Config::TypeInfo to a cel::Type. Returns an error if the type_info +// cannot be converted to a known cel::Type, a list configured with more than +// one parameter. 
+absl::StatusOr TypeInfoToType( + const Config::TypeInfo& type_info, + const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc new file mode 100644 index 000000000..ca9d0467c --- /dev/null +++ b/env/type_info_test.cc @@ -0,0 +1,127 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "env/type_info.h" + +#include +#include + +#include "common/type.h" +#include "common/type_proto.h" +#include "env/config.h" +#include "internal/proto_matchers.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/text_format.h" + +namespace cel { +namespace { + +using absl_testing::IsOk; +using testing::ValuesIn; + +struct TestCase { + Config::TypeInfo type_info; + std::string expected_type_pb; +}; + +using TypeInfoTest = testing::TestWithParam; + +TEST_P(TypeInfoTest, TypeInfo) { + const TestCase& param = GetParam(); + cel::expr::Type expected_type_pb; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(param.expected_type_pb, + &expected_type_pb)); + + google::protobuf::Arena arena; + const google::protobuf::DescriptorPool* descriptor_pool = + cel::internal::GetTestingDescriptorPool(); + ASSERT_OK_AND_ASSIGN( + cel::Type actual_type, + 
cel::TypeInfoToType(param.type_info, descriptor_pool, &arena)); + + cel::expr::Type actual_type_pb; + ASSERT_THAT(cel::TypeToProto(actual_type, &actual_type_pb), IsOk()); + EXPECT_THAT(actual_type_pb, + cel::internal::test::EqualsProto(expected_type_pb)); +} + +std::vector GetTestCases() { + return { + TestCase{ + .type_info = {.name = "int"}, + .expected_type_pb = "primitive: INT64", + }, + TestCase{ + .type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "list_type { elem_type { primitive: INT64 } }", + }, + TestCase{ + .type_info = {.name = "list"}, + .expected_type_pb = "list_type { elem_type { dyn {} }}", + }, + TestCase{ + .type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + .expected_type_pb = "map_type { key_type { primitive: STRING } " + "value_type { primitive: INT64 }}", + }, + TestCase{ + .type_info = {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + .expected_type_pb = + "message_type: 'cel.expr.conformance.proto2.TestAllTypes'", + }, + TestCase{ + .type_info = {.name = "A", + .params = {Config::TypeInfo{.name = "B", + .is_type_param = true}}}, + .expected_type_pb = + "abstract_type { name: 'A' parameter_types { type_param: 'B' } }", + }, + TestCase{ + .type_info = {.name = "any"}, + .expected_type_pb = "well_known: ANY", + }, + TestCase{ + .type_info = {.name = "timestamp"}, + .expected_type_pb = "well_known: TIMESTAMP", + }, + TestCase{ + .type_info = {.name = "google.protobuf.DoubleValue"}, + .expected_type_pb = "wrapper: DOUBLE", + }, + TestCase{ + .type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "duration"}}}, + .expected_type_pb = "type: { well_known: DURATION }", + }, + TestCase{ + .type_info = {.name = "parameterized", + .params = {{.name = "A", .is_type_param = true}, + {.name = "double"}}}, + .expected_type_pb = "abstract_type { name: 'parameterized' " + "parameter_types { type_param: 'A' } " + 
"parameter_types { primitive: DOUBLE } }", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); + +} // namespace +} // namespace cel From fb214306e1517a61220af3a5cf51395b49a0299c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 10:20:16 -0700 Subject: [PATCH 17/88] Remove support for hierarchical type_envs. This was not used anywhere and extended envs will likely need to deep copy. PiperOrigin-RevId: 892414339 --- checker/internal/type_check_env.cc | 91 ++++++++++++------------------ checker/internal/type_check_env.h | 20 +------ 2 files changed, 37 insertions(+), 74 deletions(-) diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index d856a7230..e76621435 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -34,25 +34,18 @@ namespace cel::checker_internal { const VariableDecl* absl_nullable TypeCheckEnv::LookupVariable( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->variables_.find(name); it != scope->variables_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = variables_.find(name); it != variables_.end()) { + return &it->second; } return nullptr; } const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( absl::string_view name) const { - const TypeCheckEnv* scope = this; - while (scope != nullptr) { - if (auto it = scope->functions_.find(name); it != scope->functions_.end()) { - return &it->second; - } - scope = scope->parent_; + if (auto it = functions_.find(name); it != functions_.end()) { + return &it->second; } + return nullptr; } @@ -71,17 +64,13 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return Type::Enum(enum_descriptor); } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto type = (*iter)->FindType(name); - if (!type.ok() 
|| type->has_value()) { - return type; - } + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto type = (*iter)->FindType(name); + if (!type.ok() || type->has_value()) { + return type; } - scope = scope->parent_; - } while ((scope != nullptr)); + } return absl::nullopt; } @@ -106,26 +95,21 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return decl; } } - const TypeCheckEnv* scope = this; - do { - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type, value); - if (!enum_constant.ok()) { - return enum_constant.status(); - } - if (enum_constant->has_value()) { - auto decl = - MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", - (**enum_constant).value_name), - (**enum_constant).type); - decl.set_value( - Constant(static_cast((**enum_constant).number))); - return decl; - } + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto enum_constant = (*iter)->FindEnumConstant(type, value); + if (!enum_constant.ok()) { + return enum_constant.status(); } - scope = scope->parent_; - } while (scope != nullptr); + if (enum_constant->has_value()) { + auto decl = + MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", + (**enum_constant).value_name), + (**enum_constant).type); + decl.set_value(Constant(static_cast((**enum_constant).number))); + return decl; + } + } return absl::nullopt; } @@ -165,22 +149,17 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( return cel::MessageTypeField(field_descriptor); } } - const TypeCheckEnv* scope = this; - do { - // Check the type providers in reverse registration order. - // Note: this doesn't allow for shadowing a type with a subset type of the - // same name -- the parent type provider will still be considered when - // checking field accesses. 
- for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); - ++iter) { - auto field_info = - (*iter)->FindStructTypeFieldByName(type_name, field_name); - if (!field_info.ok() || field_info->has_value()) { - return field_info; - } + // Check the type providers in reverse registration order. + // Note: this doesn't allow for shadowing a type with a subset type of the + // same name -- the prior type provider will still be considered when + // checking field accesses. + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto field_info = (*iter)->FindStructTypeFieldByName(type_name, field_name); + if (!field_info.ok() || field_info->has_value()) { + return field_info; } - scope = scope->parent_; - } while (scope != nullptr); + } return absl::nullopt; } diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index a4d242fdf..5c8b3629c 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -89,17 +89,14 @@ class TypeCheckEnv { explicit TypeCheckEnv( absl_nonnull std::shared_ptr descriptor_pool) - : descriptor_pool_(std::move(descriptor_pool)), - container_(""), - parent_(nullptr) {} + : descriptor_pool_(std::move(descriptor_pool)), container_("") {} TypeCheckEnv(absl_nonnull std::shared_ptr descriptor_pool, std::shared_ptr arena) : descriptor_pool_(std::move(descriptor_pool)), arena_(std::move(arena)), - container_(""), - parent_(nullptr) {} + container_("") {} // Move-only. TypeCheckEnv(TypeCheckEnv&&) = default; @@ -163,9 +160,6 @@ class TypeCheckEnv { functions_[decl.name()] = std::move(decl); } - const TypeCheckEnv* absl_nullable parent() const { return parent_; } - void set_parent(TypeCheckEnv* parent) { parent_ = parent; } - // Returns the declaration for the given name if it is found in the current // or any parent scope. 
// Note: the returned declaration ptr is only valid as long as no changes are @@ -184,10 +178,6 @@ class TypeCheckEnv { absl::StatusOr> LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view type_name) const; - TypeCheckEnv MakeExtendedEnvironment() const ABSL_ATTRIBUTE_LIFETIME_BOUND { - return TypeCheckEnv(this); - } - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { return descriptor_pool_.get(); } @@ -200,11 +190,6 @@ class TypeCheckEnv { } private: - explicit TypeCheckEnv(const TypeCheckEnv* absl_nonnull parent) - : descriptor_pool_(parent->descriptor_pool_), - container_(parent != nullptr ? parent->container() : ""), - parent_(parent) {} - absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; @@ -212,7 +197,6 @@ class TypeCheckEnv { // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; std::string container_; - const TypeCheckEnv* absl_nullable parent_; // Maps fully qualified names to declarations. absl::flat_hash_map variables_; From 463ddc0a6dd94979c7f3779ec6e8c68e7e89ad2b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 13:30:29 -0700 Subject: [PATCH 18/88] Delete types/interfaces related to TypeFactory and TypeManager. 
PiperOrigin-RevId: 892512872 --- common/BUILD | 3 -- common/type_factory.h | 30 ---------------- common/type_introspector.h | 2 -- common/type_manager.h | 57 ------------------------------ common/types/legacy_type_manager.h | 45 ----------------------- common/values/opaque_value.h | 2 +- eval/compiler/flat_expr_builder.h | 12 ------- eval/public/cel_type_registry.h | 2 +- 8 files changed, 2 insertions(+), 151 deletions(-) delete mode 100644 common/type_factory.h delete mode 100644 common/type_manager.h delete mode 100644 common/types/legacy_type_manager.h diff --git a/common/BUILD b/common/BUILD index da96b1c98..a4ac6a3ef 100644 --- a/common/BUILD +++ b/common/BUILD @@ -541,12 +541,9 @@ cc_library( ], ) + [ "type.h", - "type_factory.h", "type_introspector.h", - "type_manager.h", ], deps = [ - ":memory", ":type_kind", "//internal:string_pool", "@com_google_absl//absl/algorithm:container", diff --git a/common/type_factory.h b/common/type_factory.h deleted file mode 100644 index 33829ea8b..000000000 --- a/common/type_factory.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ - -namespace cel { - -// `TypeFactory` is the preferred way for constructing compound types such as -// lists, maps, structs, and opaques. It caches types and avoids constructing -// them multiple times. 
-class TypeFactory { - public: - virtual ~TypeFactory() = default; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_FACTORY_H_ diff --git a/common/type_introspector.h b/common/type_introspector.h index 7f4a19a31..159e49ab4 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -24,8 +24,6 @@ namespace cel { -class TypeFactory; - // `TypeIntrospector` is an interface which allows querying type-related // information. It handles type introspection, but not type reflection. That is, // it is not capable of instantiating new values or understanding values. Its diff --git a/common/type_manager.h b/common/type_manager.h deleted file mode 100644 index 354f4c9b8..000000000 --- a/common/type_manager.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/memory.h" -#include "common/type.h" -#include "common/type_factory.h" -#include "common/type_introspector.h" - -namespace cel { - -// `TypeManager` is an additional layer on top of `TypeFactory` and -// `TypeIntrospector` which combines the two and adds additional functionality. 
-class TypeManager : public virtual TypeFactory { - public: - virtual ~TypeManager() = default; - - // See `TypeIntrospector::FindType`. - absl::StatusOr> FindType(absl::string_view name) { - return GetTypeIntrospector().FindType(name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - // See `TypeIntrospector::FindStructTypeFieldByName`. - absl::StatusOr> FindStructTypeFieldByName( - const StructType& type, absl::string_view name) { - return GetTypeIntrospector().FindStructTypeFieldByName(type, name); - } - - protected: - virtual const TypeIntrospector& GetTypeIntrospector() const = 0; -}; - -} // namespace cel - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_MANAGER_H_ diff --git a/common/types/legacy_type_manager.h b/common/types/legacy_type_manager.h deleted file mode 100644 index 238335b52..000000000 --- a/common/types/legacy_type_manager.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ - -#include "common/memory.h" -#include "common/type_introspector.h" -#include "common/type_manager.h" - -namespace cel::common_internal { - -// `LegacyTypeManager` is an implementation which should be used when -// converting between `cel::Value` and `google::api::expr::runtime::CelValue` -// and only then. -class LegacyTypeManager : public virtual TypeManager { - public: - explicit LegacyTypeManager(const TypeIntrospector& type_introspector) - : type_introspector_(type_introspector) {} - - protected: - const TypeIntrospector& GetTypeIntrospector() const final { - return type_introspector_; - } - - private: - const TypeIntrospector& type_introspector_; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_LEGACY_TYPE_MANAGER_H_ diff --git a/common/values/opaque_value.h b/common/values/opaque_value.h index 273b7889a..57af78ae0 100644 --- a/common/values/opaque_value.h +++ b/common/values/opaque_value.h @@ -52,7 +52,7 @@ class Value; class OpaqueValueInterface; class OpaqueValueInterfaceIterator; class OpaqueValue; -class TypeFactory; + using OpaqueValueContent = CustomValueContent; struct OpaqueValueDispatcher { diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index eab1e7ff8..7d770b443 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -53,18 +53,6 @@ class FlatExprBuilder { type_registry_(env_->type_registry), use_legacy_type_provider_(use_legacy_type_provider) {} - FlatExprBuilder( - absl_nonnull std::shared_ptr env, - const cel::FunctionRegistry& function_registry, - const cel::TypeRegistry& type_registry, - const cel::RuntimeOptions& options, bool use_legacy_type_provider = false) - : env_(std::move(env)), - options_(options), - container_(options.container), - 
function_registry_(function_registry), - type_registry_(type_registry), - use_legacy_type_provider_(use_legacy_type_provider) {} - void AddAstTransform(std::unique_ptr transform) { ast_transforms_.push_back(std::move(transform)); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 290726bfe..0c01eb8e9 100644 --- a/eval/public/cel_type_registry.h +++ b/eval/public/cel_type_registry.h @@ -86,7 +86,7 @@ class CelTypeRegistry { // registry. // // This is a composited type provider that should check in order: - // - builtins (via TypeManager) + // - builtins // - custom enumerations // - registered extension type providers in the order registered. const cel::TypeProvider& GetTypeProvider() const { From 8773590dc5fd54ac50963b78bb33cbb82043d28f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 31 Mar 2026 18:05:30 -0700 Subject: [PATCH 19/88] Refactor cel::TypeIntrospector - simplify base class logic, add an explicit implementation for looking up WKTs - Make WellKnownType lookups explicit in runtime implementations. - delete extensions/protobuf/type_introspector and related. They should be unused now and don't work as expected with the type checker or runtime. 
PiperOrigin-RevId: 892637710 --- common/BUILD | 1 + common/type_introspector.cc | 55 +++++----- common/type_introspector.h | 51 ++++++++- .../thread_compatible_type_introspector.h | 34 ------ eval/compiler/flat_expr_builder.h | 2 - eval/public/structs/legacy_type_provider.cc | 8 ++ extensions/protobuf/BUILD | 36 ------ extensions/protobuf/type_introspector.cc | 80 -------------- extensions/protobuf/type_introspector.h | 58 ---------- extensions/protobuf/type_introspector_test.cc | 103 ------------------ extensions/protobuf/type_reflector.h | 41 ------- runtime/internal/BUILD | 1 + runtime/internal/runtime_type_provider.cc | 17 ++- 13 files changed, 96 insertions(+), 391 deletions(-) delete mode 100644 common/types/thread_compatible_type_introspector.h delete mode 100644 extensions/protobuf/type_introspector.cc delete mode 100644 extensions/protobuf/type_introspector.h delete mode 100644 extensions/protobuf/type_introspector_test.cc delete mode 100644 extensions/protobuf/type_reflector.h diff --git a/common/BUILD b/common/BUILD index a4ac6a3ef..8dd8921cc 100644 --- a/common/BUILD +++ b/common/BUILD @@ -548,6 +548,7 @@ cc_library( "//internal:string_pool", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/common/type_introspector.cc b/common/type_introspector.cc index c69235b3b..6d5158a2f 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -211,35 +211,6 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { } // namespace -absl::StatusOr> TypeIntrospector::FindType( - absl::string_view name) const { - const auto& well_known_types = GetWellKnownTypesMap(); - if (auto it = well_known_types.find(name); it != well_known_types.end()) { - return it->second.type; - } - return FindTypeImpl(name); -} - 
-absl::StatusOr> -TypeIntrospector::FindEnumConstant(absl::string_view type, - absl::string_view value) const { - if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { - return EnumConstant{NullType{}, "google.protobuf.NullValue", "NULL_VALUE", - 0}; - } - return FindEnumConstantImpl(type, value); -} - -absl::StatusOr> -TypeIntrospector::FindStructTypeFieldByName(absl::string_view type, - absl::string_view name) const { - const auto& well_known_types = GetWellKnownTypesMap(); - if (auto it = well_known_types.find(type); it != well_known_types.end()) { - return it->second.FieldByName(name); - } - return FindStructTypeFieldByNameImpl(type, name); -} - absl::StatusOr> TypeIntrospector::FindTypeImpl( absl::string_view) const { return absl::nullopt; @@ -257,4 +228,30 @@ TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, return absl::nullopt; } +absl::optional FindWellKnownType(absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(name); it != well_known_types.end()) { + return it->second.type; + } + return absl::nullopt; +} + +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value) { + if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { + return TypeIntrospector::EnumConstant{ + NullType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; + } + return absl::nullopt; +} + +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name) { + const auto& well_known_types = GetWellKnownTypesMap(); + if (auto it = well_known_types.find(type); it != well_known_types.end()) { + return it->second.FieldByName(name); + } + return absl::nullopt; +} + } // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h index 159e49ab4..fb6ea09c1 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -43,17 +43,23 @@ class TypeIntrospector { virtual 
~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. - absl::StatusOr> FindType(absl::string_view name) const; + absl::StatusOr> FindType(absl::string_view name) const { + return FindTypeImpl(name); + } // `FindEnumConstant` find a fully qualified enumerator name `name` in enum // type `type`. absl::StatusOr> FindEnumConstant( - absl::string_view type, absl::string_view value) const; + absl::string_view type, absl::string_view value) const { + return FindEnumConstantImpl(type, value); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in type `type`. absl::StatusOr> FindStructTypeFieldByName( - absl::string_view type, absl::string_view name) const; + absl::string_view type, absl::string_view name) const { + return FindStructTypeFieldByNameImpl(type, name); + } // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. @@ -74,6 +80,45 @@ class TypeIntrospector { absl::string_view name) const; }; +// Looks up a well-known type by name. +absl::optional FindWellKnownType(absl::string_view name); + +// Looks up a well-known enum constant by type and value. +absl::optional FindWellKnownTypeEnumConstant( + absl::string_view type, absl::string_view value); + +// Looks up a well-known struct type field by type and field name. +absl::optional FindWellKnownTypeFieldByName( + absl::string_view type, absl::string_view name); + +// `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which +// handles well known types that are treated specially by CEL. +// +// This also serves as a minimal implementation of a TypeIntrospector when no +// custom types are present. +// +// This class has no mutable state, so trivially thread-safe.
+class WellKnownTypeIntrospector : public virtual TypeIntrospector { + public: + WellKnownTypeIntrospector() = default; + + private: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final { + return FindWellKnownType(name); + } + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final { + return FindWellKnownTypeEnumConstant(type, value); + } + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final { + return FindWellKnownTypeFieldByName(type, name); + } +}; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ diff --git a/common/types/thread_compatible_type_introspector.h b/common/types/thread_compatible_type_introspector.h deleted file mode 100644 index 870ea9054..000000000 --- a/common/types/thread_compatible_type_introspector.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private - -#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ - -#include "common/type_introspector.h" - -namespace cel::common_internal { - -// `ThreadCompatibleTypeIntrospector` is a basic implementation of -// `TypeIntrospector` which is thread compatible. By default this implementation -// just returns `NOT_FOUND` for most methods. 
-class ThreadCompatibleTypeIntrospector : public virtual TypeIntrospector { - public: - ThreadCompatibleTypeIntrospector() = default; -}; - -} // namespace cel::common_internal - -#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_THREAD_COMPATIBLE_TYPE_INTROSPECTOR_H_ diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 7d770b443..aa4d0b4e5 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -23,12 +23,10 @@ #include #include "absl/base/nullability.h" -#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" #include "base/type_provider.h" -#include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" #include "runtime/function_registry.h" diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index a85f08911..f87ab9645 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -27,6 +27,7 @@ #include "common/legacy_value.h" #include "common/memory.h" #include "common/type.h" +#include "common/type_introspector.h" #include "common/value.h" #include "eval/public/message_wrapper.h" #include "eval/public/structs/legacy_type_adapter.h" @@ -175,6 +176,9 @@ LegacyTypeProvider::NewValueBuilder( absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::string_view name) const { + if (auto type = cel::FindWellKnownType(name); type.has_value()) { + return type; + } if (auto type_info = ProvideLegacyTypeInfo(name); type_info.has_value()) { const auto* descriptor = (*type_info)->GetDescriptor(MessageWrapper()); if (descriptor != nullptr) { @@ -189,6 +193,10 @@ absl::StatusOr> LegacyTypeProvider::FindTypeImpl( absl::StatusOr> LegacyTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { + if (auto result = 
cel::FindWellKnownTypeFieldByName(type, name); + result.has_value()) { + return result; + } if (auto type_info = ProvideLegacyTypeInfo(type); type_info.has_value()) { if (auto field_desc = (*type_info)->FindFieldByName(name); field_desc.has_value()) { diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index 6c3f654f9..3f4081b09 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -87,48 +87,12 @@ cc_library( ], ) -cc_library( - name = "type", - srcs = [ - "type_introspector.cc", - ], - hdrs = [ - "type_introspector.h", - ], - deps = [ - "//common:type", - "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "type_test", - srcs = [ - "type_introspector_test.cc", - ], - deps = [ - ":type", - "//common:type", - "//common:type_kind", - "//internal:testing", - "@com_google_absl//absl/types:optional", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", - "@com_google_protobuf//:protobuf", - ], -) - cc_library( name = "value", hdrs = [ - "type_reflector.h", "value.h", ], deps = [ - ":type", "//common:memory", "//common:type", "//common:value", diff --git a/extensions/protobuf/type_introspector.cc b/extensions/protobuf/type_introspector.cc deleted file mode 100644 index 8b445c359..000000000 --- a/extensions/protobuf/type_introspector.cc +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( - absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(name); - if (desc == nullptr) { - return absl::nullopt; - } - return MessageType(desc); -} - -absl::StatusOr> -ProtoTypeIntrospector::FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const { - const google::protobuf::EnumDescriptor* enum_desc = - descriptor_pool()->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. - if (enum_desc == nullptr) { - return absl::nullopt; - } - - // Note: we don't support strong enum typing at this time so only the fully - // qualified enum values are meaningful, so we don't provide any signal if the - // enum type is found but can't match the value name. 
- const google::protobuf::EnumValueDescriptor* value_desc = - enum_desc->FindValueByName(value); - if (value_desc == nullptr) { - return absl::nullopt; - } - - return TypeIntrospector::EnumConstant{ - EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), - value_desc->number()}; -} - -absl::StatusOr> -ProtoTypeIntrospector::FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. - const auto* desc = descriptor_pool()->FindMessageTypeByName(type); - if (desc == nullptr) { - return absl::nullopt; - } - const auto* field_desc = desc->FindFieldByName(name); - if (field_desc == nullptr) { - field_desc = descriptor_pool()->FindExtensionByPrintableName(desc, name); - if (field_desc == nullptr) { - return absl::nullopt; - } - } - return MessageTypeField(field_desc); -} - -} // namespace cel::extensions diff --git a/extensions/protobuf/type_introspector.h b/extensions/protobuf/type_introspector.h deleted file mode 100644 index 5eb9c3ddc..000000000 --- a/extensions/protobuf/type_introspector.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ - -#include "absl/base/nullability.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeIntrospector : public virtual TypeIntrospector { - public: - ProtoTypeIntrospector() - : ProtoTypeIntrospector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeIntrospector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : descriptor_pool_(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return descriptor_pool_; - } - - protected: - absl::StatusOr> FindTypeImpl( - absl::string_view name) const final; - - absl::StatusOr> - FindEnumConstantImpl(absl::string_view type, - absl::string_view value) const final; - - absl::StatusOr> FindStructTypeFieldByNameImpl( - absl::string_view type, absl::string_view name) const final; - - private: - const google::protobuf::DescriptorPool* absl_nonnull const descriptor_pool_; -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_INTROSPECTOR_H_ diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc deleted file mode 100644 index 0a7b21524..000000000 --- a/extensions/protobuf/type_introspector_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "extensions/protobuf/type_introspector.h" - -#include "absl/types/optional.h" -#include "common/type.h" -#include "common/type_kind.h" -#include "internal/testing.h" -#include "cel/expr/conformance/proto2/test_all_types.pb.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { -namespace { - -using ::absl_testing::IsOkAndHolds; -using ::cel::expr::conformance::proto2::TestAllTypes; -using ::testing::Eq; -using ::testing::Optional; - -TEST(ProtoTypeIntrospector, FindType) { - ProtoTypeIntrospector introspector; - EXPECT_THAT( - introspector.FindType(TestAllTypes::descriptor()->full_name()), - IsOkAndHolds(Optional(Eq(MessageType(TestAllTypes::GetDescriptor()))))); - EXPECT_THAT(introspector.FindType("type.that.does.not.Exist"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindStructTypeFieldByName) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto field, introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "single_int32")); - ASSERT_TRUE(field.has_value()); - EXPECT_THAT(field->name(), Eq("single_int32")); - EXPECT_THAT(field->number(), Eq(1)); - EXPECT_THAT( - introspector.FindStructTypeFieldByName( - TestAllTypes::descriptor()->full_name(), "field_that_does_not_exist"), - IsOkAndHolds(Eq(absl::nullopt))); - EXPECT_THAT(introspector.FindStructTypeFieldByName("type.that.does.not.Exist", - "does_not_matter"), - IsOkAndHolds(Eq(absl::nullopt))); -} - -TEST(ProtoTypeIntrospector, FindEnumConstant) { - ProtoTypeIntrospector introspector; - const 
auto* enum_desc = TestAllTypes::NestedEnum_descriptor(); - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "BAZ")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); - EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); - EXPECT_EQ(enum_constant->value_name, "BAZ"); - EXPECT_EQ(enum_constant->number, 2); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantNull) { - ProtoTypeIntrospector introspector; - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant("google.protobuf.NullValue", "NULL_VALUE")); - ASSERT_TRUE(enum_constant.has_value()); - EXPECT_EQ(enum_constant->type.kind(), TypeKind::kNull); - EXPECT_EQ(enum_constant->type_full_name, "google.protobuf.NullValue"); - EXPECT_EQ(enum_constant->value_name, "NULL_VALUE"); - EXPECT_EQ(enum_constant->number, 0); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownEnum) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN(auto enum_constant, - introspector.FindEnumConstant("NotARealEnum", "BAZ")); - EXPECT_FALSE(enum_constant.has_value()); -} - -TEST(ProtoTypeIntrospector, FindEnumConstantUnknownValue) { - ProtoTypeIntrospector introspector; - - ASSERT_OK_AND_ASSIGN( - auto enum_constant, - introspector.FindEnumConstant( - "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", "QUX")); - ASSERT_FALSE(enum_constant.has_value()); -} - -} // namespace -} // namespace cel::extensions diff --git a/extensions/protobuf/type_reflector.h b/extensions/protobuf/type_reflector.h deleted file mode 100644 index 4665235fe..000000000 --- a/extensions/protobuf/type_reflector.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ -#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ - -#include "absl/base/nullability.h" -#include "common/type_reflector.h" -#include "extensions/protobuf/type_introspector.h" -#include "google/protobuf/descriptor.h" - -namespace cel::extensions { - -class ProtoTypeReflector : public TypeReflector, public ProtoTypeIntrospector { - public: - ProtoTypeReflector() - : ProtoTypeReflector(google::protobuf::DescriptorPool::generated_pool()) {} - - explicit ProtoTypeReflector( - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) - : ProtoTypeIntrospector(descriptor_pool) {} - - const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const { - return ProtoTypeIntrospector::descriptor_pool(); - } -}; - -} // namespace cel::extensions - -#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_TYPE_REFLECTOR_H_ diff --git a/runtime/internal/BUILD b/runtime/internal/BUILD index 28f9bd1cb..1223ff6d1 100644 --- a/runtime/internal/BUILD +++ b/runtime/internal/BUILD @@ -195,6 +195,7 @@ cc_library( deps = [ "//common:type", "//common:value", + "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/runtime/internal/runtime_type_provider.cc b/runtime/internal/runtime_type_provider.cc index 1acb52223..40f5ff575 100644 --- a/runtime/internal/runtime_type_provider.cc +++ b/runtime/internal/runtime_type_provider.cc @@ -44,8 +44,10 @@ absl::Status 
RuntimeTypeProvider::RegisterType(const OpaqueType& type) { absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindType` handles those directly. + auto type = FindWellKnownType(name); + if (type.has_value()) { + return type; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(name); if (desc != nullptr) { return MessageType(desc); @@ -60,9 +62,12 @@ absl::StatusOr> RuntimeTypeProvider::FindTypeImpl( absl::StatusOr> RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::string_view value) const { + auto enum_constant = FindWellKnownTypeEnumConstant(type, value); + if (enum_constant.has_value()) { + return enum_constant; + } const google::protobuf::EnumDescriptor* enum_desc = descriptor_pool_->FindEnumTypeByName(type); - // google.protobuf.NullValue is special cased in the base class. if (enum_desc == nullptr) { return absl::nullopt; } @@ -84,8 +89,10 @@ RuntimeTypeProvider::FindEnumConstantImpl(absl::string_view type, absl::StatusOr> RuntimeTypeProvider::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { - // We do not have to worry about well known types here. - // `TypeIntrospector::FindStructTypeFieldByName` handles those directly. 
+ auto field = FindWellKnownTypeFieldByName(type, name); + if (field.has_value()) { + return field; + } const auto* desc = descriptor_pool_->FindMessageTypeByName(type); if (desc == nullptr) { return absl::nullopt; From 6e1cb5311aa17ddc2ed9e4fd30e042d4c592b359 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 1 Apr 2026 10:27:42 -0700 Subject: [PATCH 20/88] Fix compatibility with newer versions of protobuf PiperOrigin-RevId: 893000224 --- internal/well_known_types.cc | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index f66a9360b..c736be69f 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -71,6 +71,17 @@ using ::google::protobuf::util::TimeUtil; using CppStringType = ::google::protobuf::FieldDescriptor::CppStringType; +FieldDescriptor::Label GetFieldLabel( + const FieldDescriptor* absl_nonnull field) { + if (field->is_required()) { + return FieldDescriptor::LABEL_REQUIRED; + } else if (field->is_repeated()) { + return FieldDescriptor::LABEL_REPEATED; + } else { + return FieldDescriptor::LABEL_OPTIONAL; + } +} + absl::string_view FlatStringValue( const StringValue& value ABSL_ATTRIBUTE_LIFETIME_BOUND, std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { @@ -264,11 +275,11 @@ absl::string_view LabelToString(FieldDescriptor::Label label) { absl::Status CheckFieldCardinality(const FieldDescriptor* absl_nonnull field, FieldDescriptor::Label label) { - if (ABSL_PREDICT_FALSE(field->label() != label)) { - return absl::InvalidArgumentError( - absl::StrCat("unexpected field cardinality for protocol buffer message " - "well known type: ", - field->full_name(), " ", LabelToString(field->label()))); + if (ABSL_PREDICT_FALSE(GetFieldLabel(field) != label)) { + return absl::InvalidArgumentError(absl::StrCat( + "unexpected field cardinality for protocol buffer message " + "well known type: ", + field->full_name(), " ", 
LabelToString(GetFieldLabel(field)))); } return absl::OkStatus(); } From cf07a8775491c705739416adcb3a77f3cf5c0916 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 10:14:53 -0700 Subject: [PATCH 21/88] Expose source_info position -> SourceLocation helper. PiperOrigin-RevId: 893571451 --- checker/internal/type_checker_impl.cc | 58 +++++---------------------- common/BUILD | 2 + common/ast.cc | 38 ++++++++++++++++++ common/ast.h | 8 ++++ common/ast_test.cc | 52 ++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 48 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1e9995b19..28cbf21e0 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -48,7 +48,6 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" -#include "common/source.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -66,43 +65,6 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } -SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { - const auto& source_info = ast.source_info(); - auto iter = source_info.positions().find(expr_id); - if (iter == source_info.positions().end()) { - return SourceLocation{}; - } - int32_t absolute_position = iter->second; - if (absolute_position < 0) { - return SourceLocation{}; - } - - // Find the first line offset that is greater than the absolute position. - int32_t line_idx = -1; - int32_t offset = 0; - for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { - int32_t next_offset = source_info.line_offsets()[i]; - if (next_offset <= offset) { - // Line offset is not monotonically increasing, so line information is - // invalid. 
- return SourceLocation{}; - } - if (absolute_position < next_offset) { - line_idx = i; - break; - } - offset = next_offset; - } - - if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { - return SourceLocation{}; - } - - int32_t rel_position = absolute_position - offset; - - return SourceLocation{line_idx + 1, rel_position}; -} - // Flatten the type to the AST type representation to remove any lifecycle // dependency between the type check environment and the AST. // @@ -362,7 +324,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportMissingReference(const Expr& expr, absl::string_view name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", container_, "')"))); } @@ -370,7 +332,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, absl::string_view struct_name) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("undefined field '", field_name, "' not found in struct '", struct_name, "'"))); } @@ -378,7 +340,7 @@ class ResolveVisitor : public AstVisitorBase { void ReportTypeMismatch(int64_t expr_id, const Type& expected, const Type& actual) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr_id), + ast_->ComputeSourceLocation(expr_id), absl::StrCat("expected type '", FormatTypeName(inference_context_->FinalizeType(expected)), "' but found '", @@ -408,7 +370,7 @@ class ResolveVisitor : public AstVisitorBase { } if (!inference_context_->IsAssignable(value_type, field_type)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, field.id()), + ast_->ComputeSourceLocation(field.id()), absl::StrCat( "expected type of field '", field_info->name(), "' is '", 
FormatTypeName(inference_context_->FinalizeType(field_type)), @@ -553,7 +515,7 @@ void ResolveVisitor::PostVisitConst(const Expr& expr, break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("unsupported constant type: ", constant.kind().index()))); types_[&expr] = ErrorType(); @@ -605,7 +567,7 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { // To match the Go implementation, we just warn here, but in the future // we should consider making this an error. ReportIssue(TypeCheckIssue( - Severity::kWarning, ComputeSourceLocation(*ast_, key->id()), + Severity::kWarning, ast_->ComputeSourceLocation(key->id()), absl::StrCat( "unsupported map key type: ", FormatTypeName(inference_context_->FinalizeType(key_type))))); @@ -711,7 +673,7 @@ void ResolveVisitor::PostVisitStruct(const Expr& expr, if (resolved_type.kind() != TypeKind::kStruct && !IsWellKnownMessageType(resolved_name)) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("type '", resolved_name, "' does not support message creation"))); types_[&expr] = ErrorType(); @@ -862,7 +824,7 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( break; default: ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, comprehension.iter_range().id()), + ast_->ComputeSourceLocation(comprehension.iter_range().id()), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(range_type)), @@ -933,7 +895,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, if (!resolution.has_value()) { ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, expr.id()), + ast_->ComputeSourceLocation(expr.id()), absl::StrCat("found no matching overload for '", decl.name(), "' applied to '(", absl::StrJoin(arg_types, ", ", @@ -1133,7 +1095,7 @@ 
absl::optional ResolveVisitor::CheckFieldType(int64_t id, } ReportIssue(TypeCheckIssue::CreateError( - ComputeSourceLocation(*ast_, id), + ast_->ComputeSourceLocation(id), absl::StrCat( "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), diff --git a/common/BUILD b/common/BUILD index 8dd8921cc..e289ef413 100644 --- a/common/BUILD +++ b/common/BUILD @@ -25,6 +25,7 @@ cc_library( hdrs = ["ast.h"], deps = [ ":expr", + ":source", "//common/ast:metadata", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", @@ -39,6 +40,7 @@ cc_test( deps = [ ":ast", ":expr", + ":source", "//internal:testing", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/common/ast.cc b/common/ast.cc index aea153197..48b6f5e0b 100644 --- a/common/ast.cc +++ b/common/ast.cc @@ -19,6 +19,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "common/ast/metadata.h" +#include "common/source.h" namespace cel { namespace { @@ -57,4 +58,41 @@ const Reference* absl_nullable Ast::GetReference(int64_t expr_id) const { return &iter->second; } +SourceLocation Ast::ComputeSourceLocation(int64_t expr_id) const { + const auto& source_info = this->source_info(); + auto iter = source_info.positions().find(expr_id); + if (iter == source_info.positions().end()) { + return SourceLocation{}; + } + int32_t absolute_position = iter->second; + if (absolute_position < 0) { + return SourceLocation{}; + } + + // Find the first line offset that is greater than the absolute position. + int32_t line_idx = -1; + int32_t offset = 0; + for (int32_t i = 0; i < source_info.line_offsets().size(); ++i) { + int32_t next_offset = source_info.line_offsets()[i]; + if (next_offset <= offset) { + // Line offset is not monotonically increasing, so line information is + // invalid. 
+ return SourceLocation{}; + } + if (absolute_position < next_offset) { + line_idx = i; + break; + } + offset = next_offset; + } + + if (line_idx < 0 || line_idx >= source_info.line_offsets().size()) { + return SourceLocation{}; + } + + int32_t rel_position = absolute_position - offset; + + return SourceLocation{line_idx + 1, rel_position}; +} + } // namespace cel diff --git a/common/ast.h b/common/ast.h index 1b07b9878..db336f52d 100644 --- a/common/ast.h +++ b/common/ast.h @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" +#include "common/source.h" namespace cel { @@ -135,6 +136,13 @@ class Ast final { expr_version_ = expr_version; } + // Computes the source location (line and column) for the given expression id + // from the source info (which stores absolute positions). + // + // Returns a default (empty) source location if the expression id is not found + // or the source info is not populated correctly. + SourceLocation ComputeSourceLocation(int64_t expr_id) const; + private: Expr root_expr_; SourceInfo source_info_; diff --git a/common/ast_test.cc b/common/ast_test.cc index 744b9e8d3..56e1bcd1e 100644 --- a/common/ast_test.cc +++ b/common/ast_test.cc @@ -18,6 +18,7 @@ #include "absl/container/flat_hash_map.h" #include "common/expr.h" +#include "common/source.h" #include "internal/testing.h" namespace cel { @@ -132,5 +133,56 @@ TEST(AstImpl, CheckedExprDeepCopy) { EXPECT_EQ(ast.source_info().syntax_version(), "1.0"); } +TEST(AstImpl, ComputeSourceLocation) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20, 30}); + source_info.mutable_positions()[1] = 0; // Start of first line + source_info.mutable_positions()[2] = 5; // Middle of first line + source_info.mutable_positions()[3] = 10; // ... 
+ source_info.mutable_positions()[4] = 15; + source_info.mutable_positions()[5] = 20; + source_info.mutable_positions()[6] = 25; + + Ast ast(Expr{}, std::move(source_info)); + + EXPECT_EQ(ast.ComputeSourceLocation(1), (SourceLocation{1, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(2), (SourceLocation{1, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(3), (SourceLocation{2, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(4), (SourceLocation{2, 5})); + EXPECT_EQ(ast.ComputeSourceLocation(5), (SourceLocation{3, 0})); + EXPECT_EQ(ast.ComputeSourceLocation(6), (SourceLocation{3, 5})); +} + +TEST(AstImpl, ComputeSourceLocationFailures) { + SourceInfo source_info; + source_info.set_line_offsets({10, 20}); + source_info.mutable_positions()[1] = -1; // Negative position + source_info.mutable_positions()[2] = 25; // Beyond last line offset + // ID 3 is missing + + Ast ast; + ast.mutable_source_info() = std::move(source_info); + + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(2), SourceLocation{}); + EXPECT_EQ(ast.ComputeSourceLocation(3), SourceLocation{}); +} + +TEST(AstImpl, ComputeSourceLocationInvalidLineOffsets) { + { + // Empty line offsets + Ast ast; + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } + { + // Non-monotonic + SourceInfo source_info; + source_info.set_line_offsets({10, 5}); + source_info.mutable_positions()[1] = 12; + Ast ast(Expr{}, std::move(source_info)); + EXPECT_EQ(ast.ComputeSourceLocation(1), SourceLocation{}); + } +} + } // namespace } // namespace cel From 6bf474b38bcd754488e44ef30f9c61031ee47b37 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 2 Apr 2026 10:57:13 -0700 Subject: [PATCH 22/88] No public description PiperOrigin-RevId: 893593905 --- common/optional_ref.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/common/optional_ref.h b/common/optional_ref.h index 454926c80..c7ba580fc 100644 --- a/common/optional_ref.h +++ b/common/optional_ref.h @@ -84,7 
+84,12 @@ class optional_ref final { constexpr T& value() const { return ABSL_PREDICT_TRUE(has_value()) ? *value_ - : (absl::optional().value(), *value_); + // Replicate the same error logic as in `absl::optional`'s + // `value()`. It either throws an exception or aborts the + // program. We intentionally ignore the return value of + // the constructed optional's value as we only need to run + // the code for error checking. + : ((void)absl::optional().value(), *value_); } constexpr T& operator*() const { From 74e7666b01c14dad63052d7e0b6c2a7673161e8d Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 13:46:30 -0700 Subject: [PATCH 23/88] Initial implementation for Ast Validator Adds a new `Validator` type for applying semantic checks on a compiled expression. Adds timestamp and duration literal validators as examples for `Validations`. PiperOrigin-RevId: 893679645 --- checker/validation_result.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/checker/validation_result.h b/checker/validation_result.h index 45f949739..8c84a84da 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -58,6 +58,8 @@ class ValidationResult { absl::Span GetIssues() const { return issues_; } + void AddIssue(TypeCheckIssue issue) { issues_.push_back(std::move(issue)); } + // The source expression may optionally be set if it is available. const cel::Source* absl_nullable GetSource() const { return source_.get(); } From 082088649a2af3342a6fbdcb8f338d4c6557dc70 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 2 Apr 2026 15:30:04 -0700 Subject: [PATCH 24/88] Add proto type introspector to C++ type checker. Intended to mediate field lookups so the checker follows any options consistently. Adds support for resolving fields via JSON name. 
PiperOrigin-RevId: 893734042 --- checker/BUILD | 1 + checker/checker_options.h | 8 + checker/internal/BUILD | 36 ++- .../descriptor_pool_type_introspector.cc | 245 ++++++++++++++++++ .../descriptor_pool_type_introspector.h | 105 ++++++++ .../descriptor_pool_type_introspector_test.cc | 175 +++++++++++++ checker/internal/test_ast_helpers.cc | 22 +- checker/internal/type_check_env.cc | 88 ++----- checker/internal/type_check_env.h | 27 +- checker/internal/type_checker_builder_impl.cc | 57 +++- .../type_checker_builder_impl_test.cc | 1 - checker/internal/type_checker_impl.cc | 6 + checker/type_checker_builder_factory_test.cc | 108 ++++++++ common/type_introspector.cc | 24 +- common/type_introspector.h | 33 +++ internal/BUILD | 1 + testutil/BUILD | 6 + testutil/test_json_names.proto | 31 +++ 18 files changed, 869 insertions(+), 105 deletions(-) create mode 100644 checker/internal/descriptor_pool_type_introspector.cc create mode 100644 checker/internal/descriptor_pool_type_introspector.h create mode 100644 checker/internal/descriptor_pool_type_introspector_test.cc create mode 100644 testutil/test_json_names.proto diff --git a/checker/BUILD b/checker/BUILD index 42e37e81d..d5eb3601c 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -128,6 +128,7 @@ cc_test( ":type_checker_builder_factory", ":validation_result", "//checker/internal:test_ast_helpers", + "//common:ast", "//common:decl", "//common:type", "//internal:status_macros", diff --git a/checker/checker_options.h b/checker/checker_options.h index 0b6d1af7f..cb85337fa 100644 --- a/checker/checker_options.h +++ b/checker/checker_options.h @@ -95,6 +95,14 @@ struct CheckerOptions { // Temporary flag to allow rolling out the change. No functional changes to // evaluation behavior in either mode. bool enable_function_name_in_reference = true; + + // If true, the checker will use the proto json field names for protobuf + // messages. 
Unlike protojson parsers, it will not accept the standard proto + // field names as valid json field names. + // + // Note: The checked AST will contain the json field names and an extension + // tag, but will require runtime support for resolving the json field names. + bool use_json_field_names = false; }; } // namespace cel diff --git a/checker/internal/BUILD b/checker/internal/BUILD index c539a2cc9..3f64417a0 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -27,10 +27,11 @@ cc_library( hdrs = ["test_ast_helpers.h"], deps = [ "//common:ast", - "//extensions/protobuf:ast_converters", "//internal:status_macros", "//parser", "//parser:options", + "//parser:parser_interface", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -64,6 +65,7 @@ cc_library( srcs = ["type_check_env.cc"], hdrs = ["type_check_env.h"], deps = [ + ":descriptor_pool_type_introspector", "//common:constant", "//common:decl", "//common:type", @@ -118,6 +120,7 @@ cc_library( "type_checker_impl.h", ], deps = [ + ":descriptor_pool_type_introspector", ":format_type_name", ":namespace_generator", ":type_check_env", @@ -261,14 +264,35 @@ cc_library( ], ) +cc_library( + name = "descriptor_pool_type_introspector", + srcs = ["descriptor_pool_type_introspector.cc"], + hdrs = ["descriptor_pool_type_introspector.h"], + deps = [ + "//common:type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( - name = "format_type_name_test", - srcs = ["format_type_name_test.cc"], + name = 
"descriptor_pool_type_introspector_test", + srcs = ["descriptor_pool_type_introspector_test.cc"], deps = [ - ":format_type_name", + ":descriptor_pool_type_introspector", "//common:type", "//internal:testing", - "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", - "@com_google_protobuf//:protobuf", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", ], ) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc new file mode 100644 index 000000000..f6001e947 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -0,0 +1,245 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +// Standard implementation for field lookups. 
+// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr> +FindStructTypeFieldByNameDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type, absl::string_view name) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return absl::nullopt; + } + const google::protobuf::FieldDescriptor* absl_nullable field = + descriptor->FindFieldByName(name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + + field = descriptor_pool->FindExtensionByPrintableName(descriptor, name); + if (field != nullptr) { + return StructTypeField(MessageTypeField(field)); + } + return absl::nullopt; +} + +// Standard implementation for listing fields. +// Avoids building a FieldTable and just checks the DescriptorPool directly. +absl::StatusOr< + absl::optional>> +ListStructTypeFieldsDirectly( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type) { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool->FindMessageTypeByName(type); + if (descriptor == nullptr) { + return absl::nullopt; + } + + std::vector extensions; + descriptor_pool->FindAllExtensions(descriptor, &extensions); + + std::vector fields; + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back({field->name(), StructTypeField(MessageTypeField(field))}); + } + + return fields; +} + +} // namespace + +using Field = DescriptorPoolTypeIntrospector::Field; + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(name); + if (descriptor != nullptr) { + 
return Type::Message(descriptor); + } + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + return Type::Enum(enum_descriptor); + } + return absl::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const { + const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = + descriptor_pool_->FindEnumTypeByName(type); + if (enum_descriptor != nullptr) { + const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = + enum_descriptor->FindValueByName(value); + if (enum_value_descriptor == nullptr) { + return absl::nullopt; + } + return EnumConstant{ + .type = Type::Enum(enum_descriptor), + .type_full_name = enum_descriptor->full_name(), + .value_name = enum_value_descriptor->name(), + .number = enum_value_descriptor->number(), + }; + } + return absl::nullopt; +} + +absl::StatusOr> +DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const { + if (!use_json_name_) { + return FindStructTypeFieldByNameDirectly(descriptor_pool_, type, name); + } + + const FieldTable* field_table = GetFieldTable(type); + + if (field_table == nullptr) { + return absl::nullopt; + } + + if (auto it = field_table->json_name_map.find(name); + it != field_table->json_name_map.end()) { + return field_table->fields[it->second].field; + } + + if (auto it = field_table->extension_name_map.find(name); + it != field_table->extension_name_map.end()) { + return field_table->fields[it->second].field; + } + + return absl::nullopt; +} + +absl::StatusOr< + absl::optional>> +DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( + absl::string_view type) const { + if (!use_json_name_) { + return ListStructTypeFieldsDirectly(descriptor_pool_, type); + } + + const FieldTable* field_table = GetFieldTable(type); + if (field_table 
== nullptr) { + return absl::nullopt; + } + std::vector fields; + fields.reserve(field_table->non_extensions.size()); + for (const auto& field : field_table->non_extensions) { + fields.push_back({field.json_name, field.field}); + } + return fields; +} + +const DescriptorPoolTypeIntrospector::FieldTable* +DescriptorPoolTypeIntrospector::GetFieldTable( + absl::string_view type_name) const { + absl::MutexLock lock(mu_); + if (auto it = field_tables_.find(type_name); it != field_tables_.end()) { + return it->second.get(); + } + if (cel::IsWellKnownMessageType(type_name)) { + return nullptr; + } + const google::protobuf::Descriptor* absl_nullable descriptor = + descriptor_pool_->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return nullptr; + } + absl::string_view stable_type_name = descriptor->full_name(); + ABSL_DCHECK(stable_type_name == type_name); + std::unique_ptr field_table = CreateFieldTable(descriptor); + const FieldTable* field_table_ptr = field_table.get(); + field_tables_[stable_type_name] = std::move(field_table); + return field_table_ptr; +} + +std::unique_ptr +DescriptorPoolTypeIntrospector::CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const { + ABSL_DCHECK(!IsWellKnownMessageType(descriptor)); + std::vector fields; + absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + + std::vector extensions; + descriptor_pool_->FindAllExtensions(descriptor, &extensions); + fields.reserve(descriptor->field_count() + extensions.size()); + + for (int i = 0; i < descriptor->field_count(); i++) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(field)), + .json_name = field->json_name(), + .is_extension = false, + }); + field_name_map[field->name()] = fields.size() - 1; + if (use_json_name_ && !field->json_name().empty()) { + json_name_map[field->json_name()] = 
fields.size() - 1; + } + } + int non_extension_count = fields.size(); + + for (const google::protobuf::FieldDescriptor* extension : extensions) { + fields.push_back(Field{ + .field = StructTypeField(MessageTypeField(extension)), + .json_name = "", + .is_extension = true, + }); + extension_name_map[extension->full_name()] = fields.size() - 1; + } + int extension_count = fields.size() - non_extension_count; + auto result = std::make_unique(); + result->descriptor = descriptor; + result->fields = std::move(fields); + result->non_extensions = + absl::MakeConstSpan(result->fields).subspan(0, non_extension_count); + result->extensions = absl::MakeConstSpan(result->fields) + .subspan(non_extension_count, extension_count); + result->json_name_map = std::move(json_name_map); + result->field_name_map = std::move(field_name_map); + result->extension_name_map = std::move(extension_name_map); + return result; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/descriptor_pool_type_introspector.h b/checker/internal/descriptor_pool_type_introspector.h new file mode 100644 index 000000000..8a970ea00 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Implementation of `TypeIntrospector` that uses a `google::protobuf::DescriptorPool`. +// +// This is used by the type checker to resolve protobuf types and their fields +// and apply any options like using JSON names. +// +// Neither copyable nor movable. Should be managed by a TypeCheckEnv. +class DescriptorPoolTypeIntrospector : public TypeIntrospector { + public: + struct Field { + StructTypeField field; + absl::string_view json_name; + bool is_extension = false; + }; + + DescriptorPoolTypeIntrospector() = delete; + explicit DescriptorPoolTypeIntrospector( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) + : descriptor_pool_(descriptor_pool) {} + + DescriptorPoolTypeIntrospector(const DescriptorPoolTypeIntrospector&) = + delete; + DescriptorPoolTypeIntrospector& operator=( + const DescriptorPoolTypeIntrospector&) = delete; + DescriptorPoolTypeIntrospector(DescriptorPoolTypeIntrospector&&) = delete; + DescriptorPoolTypeIntrospector& operator=(DescriptorPoolTypeIntrospector&&) = + delete; + + void set_use_json_name(bool use_json_name) { use_json_name_ = use_json_name; } + + bool use_json_name() const { return use_json_name_; } + + private: + struct FieldTable { + const google::protobuf::Descriptor* absl_nonnull descriptor; + std::vector fields; + absl::Span non_extensions; + absl::Span extensions; + 
absl::flat_hash_map json_name_map; + absl::flat_hash_map field_name_map; + absl::flat_hash_map extension_name_map; + }; + + absl::StatusOr> FindTypeImpl( + absl::string_view name) const final; + + absl::StatusOr> FindEnumConstantImpl( + absl::string_view type, absl::string_view value) const final; + + absl::StatusOr> FindStructTypeFieldByNameImpl( + absl::string_view type, absl::string_view name) const final; + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final; + + std::unique_ptr CreateFieldTable( + const google::protobuf::Descriptor* absl_nonnull descriptor) const; + + const FieldTable* GetFieldTable(absl::string_view type_name) const; + + // Cached map of type to field table. + mutable absl::flat_hash_map> + field_tables_ ABSL_GUARDED_BY(mu_); + + mutable absl::Mutex mu_; + bool use_json_name_ = false; + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_DESCRIPTOR_POOL_TYPE_INTROSPECTOR_H_ diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc new file mode 100644 index 000000000..e2fdc9d40 --- /dev/null +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "checker/internal/descriptor_pool_type_introspector.h" + +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "common/type.h" +#include "common/type_introspector.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::AllOf; +using ::testing::Contains; +using ::testing::Eq; +using ::testing::Not; +using ::testing::Optional; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::Truly; + +TEST(DescriptorPoolTypeIntrospectorTest, FindType) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + EXPECT_THAT(introspector.FindType("cel.expr.conformance.proto3.TestAllTypes"), + IsOkAndHolds(Optional(Property(&Type::IsMessage, true)))); + EXPECT_THAT(introspector.FindType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), + IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); + EXPECT_THAT(introspector.FindType("non.existent.Type"), + IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto result = introspector.FindEnumConstant( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", "FOO"); + ASSERT_THAT(result, IsOkAndHolds(Optional(AllOf( + Truly([](const TypeIntrospector::EnumConstant& v) { + return v.value_name == "FOO" && v.number == 0; + }))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByName) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + introspector.set_use_json_name(false); + + ASSERT_THAT(field, + 
IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameJsonNameIgnored) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(false); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + EXPECT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto2.TestAllTypes", + "cel.expr.conformance.proto2.int32_ext"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + auto field = introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); + + ASSERT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + FindStructTypeFieldByNameWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + + absl::StatusOr> field = + introspector.FindStructTypeFieldByName( + "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); + + ASSERT_THAT(field, + IsOkAndHolds(Optional(Property(&StructTypeField::GetType, + Property(&Type::IsInt, true))))); +} + +MATCHER_P(FieldListingIs, field_name, "") { return arg.name == field_name; } + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { + DescriptorPoolTypeIntrospector introspector( + 
internal::GetTestingDescriptorPool()); + absl::StatusOr< + absl::optional>> + fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(*fields, Optional(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeExtensions) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto2.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(259)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("single_int64"))); + EXPECT_THAT( + **fields, + Not(Contains(FieldListingIs("cel.expr.conformance.proto2.int32_ext")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, + ListFieldsForStructTypeWithJsonNameOpt) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + introspector.set_use_json_name(true); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); + EXPECT_THAT(**fields, Contains(FieldListingIs("singleInt64"))); + EXPECT_THAT(**fields, Not(Contains(FieldListingIs("single_int64")))); +} + +TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { + DescriptorPoolTypeIntrospector introspector( + internal::GetTestingDescriptorPool()); + auto fields = introspector.ListFieldsForStructType( + "cel.expr.conformance.proto3.SomeOtherType"); + EXPECT_THAT(fields, IsOkAndHolds(Eq(absl::nullopt))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/test_ast_helpers.cc b/checker/internal/test_ast_helpers.cc index 6ef7c2c05..543f70a89 100644 --- a/checker/internal/test_ast_helpers.cc +++ b/checker/internal/test_ast_helpers.cc @@ -14,29 +14,31 @@ #include "checker/internal/test_ast_helpers.h" #include 
-#include +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "common/ast.h" -#include "extensions/protobuf/ast_converters.h" #include "internal/status_macros.h" #include "parser/options.h" #include "parser/parser.h" +#include "parser/parser_interface.h" namespace cel::checker_internal { -using ::cel::extensions::CreateAstFromParsedExpr; -using ::google::api::expr::parser::Parse; - absl::StatusOr> MakeTestParsedAst( absl::string_view expression) { - static ParserOptions options; - options.enable_optional_syntax = true; - CEL_ASSIGN_OR_RETURN(auto parsed, - Parse(expression, /*description=*/expression, options)); + static const cel::Parser* parser = []() { + cel::ParserOptions options = {.enable_optional_syntax = true}; + auto parser = NewParserBuilder(options)->Build(); + ABSL_CHECK_OK(parser); + return parser->release(); + }(); - return CreateAstFromParsedExpr(std::move(parsed)); + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + return parser->Parse(*source); } } // namespace cel::checker_internal diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index e76621435..c080326cb 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -28,7 +28,6 @@ #include "common/type_introspector.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" -#include "google/protobuf/descriptor.h" namespace cel::checker_internal { @@ -51,23 +50,10 @@ const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. 
- const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(name); - if (descriptor != nullptr) { - return Type::Message(descriptor); - } - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(name); - if (enum_descriptor != nullptr) { - return Type::Enum(enum_descriptor); - } - } - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto type = (*iter)->FindType(name); - if (!type.ok() || type->has_value()) { + CEL_ASSIGN_OR_RETURN(auto type, (*iter)->FindType(name)); + if (type.has_value()) { return type; } } @@ -76,37 +62,15 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = - descriptor_pool_->FindEnumTypeByName(type); - if (enum_descriptor != nullptr) { - const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = - enum_descriptor->FindValueByName(value); - if (enum_value_descriptor == nullptr) { - return absl::nullopt; - } - auto decl = - MakeVariableDecl(absl::StrCat(enum_descriptor->full_name(), ".", - enum_value_descriptor->name()), - Type::Enum(enum_descriptor)); - decl.set_value( - Constant(static_cast(enum_value_descriptor->number()))); - return decl; - } - } - for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto enum_constant = (*iter)->FindEnumConstant(type, value); - if (!enum_constant.ok()) { - return enum_constant.status(); - } - if (enum_constant->has_value()) { - auto decl = - MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", 
- (**enum_constant).value_name), - (**enum_constant).type); - decl.set_value(Constant(static_cast((**enum_constant).number))); + CEL_ASSIGN_OR_RETURN(auto enum_constant, + (*iter)->FindEnumConstant(type, value)); + if (enum_constant.has_value()) { + auto decl = MakeVariableDecl(absl::StrCat(enum_constant->type_full_name, + ".", enum_constant->value_name), + enum_constant->type); + decl.set_value(Constant(static_cast(enum_constant->number))); return decl; } } @@ -132,32 +96,16 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { - { - // Check the descriptor pool first, then fallback to custom type providers. - const google::protobuf::Descriptor* absl_nullable descriptor = - descriptor_pool_->FindMessageTypeByName(type_name); - if (descriptor != nullptr) { - const google::protobuf::FieldDescriptor* absl_nullable field_descriptor = - descriptor->FindFieldByName(field_name); - if (field_descriptor == nullptr) { - field_descriptor = descriptor_pool_->FindExtensionByPrintableName( - descriptor, field_name); - if (field_descriptor == nullptr) { - return absl::nullopt; - } - } - return cel::MessageTypeField(field_descriptor); - } - } - // Check the type providers in reverse registration order. + // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the - // same name -- the prior type provider will still be considered when + // same name -- the later type provider will still be considered when // checking field accesses. 
- for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { - auto field_info = (*iter)->FindStructTypeFieldByName(type_name, field_name); - if (!field_info.ok() || field_info->has_value()) { - return field_info; + CEL_ASSIGN_OR_RETURN( + auto field, (*iter)->FindStructTypeFieldByName(type_name, field_name)); + if (field.has_value()) { + return field; } } return absl::nullopt; diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 5c8b3629c..520b0eab6 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -28,6 +28,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "checker/internal/descriptor_pool_type_introspector.h" #include "common/constant.h" #include "common/decl.h" #include "common/type.h" @@ -89,14 +90,15 @@ class TypeCheckEnv { explicit TypeCheckEnv( absl_nonnull std::shared_ptr descriptor_pool) - : descriptor_pool_(std::move(descriptor_pool)), container_("") {} - - TypeCheckEnv(absl_nonnull std::shared_ptr - descriptor_pool, - std::shared_ptr arena) : descriptor_pool_(std::move(descriptor_pool)), - arena_(std::move(arena)), - container_("") {} + container_(""), + proto_type_introspector_( + std::make_shared( + descriptor_pool_.get())) { + type_providers_.push_back( + std::make_shared()); + type_providers_.push_back(proto_type_introspector_); + } // Move-only. 
TypeCheckEnv(TypeCheckEnv&&) = default; @@ -108,6 +110,13 @@ class TypeCheckEnv { container_ = std::move(container); } + const DescriptorPoolTypeIntrospector& proto_type_introspector() const { + return *proto_type_introspector_; + } + DescriptorPoolTypeIntrospector& proto_type_introspector() { + return *proto_type_introspector_; + } + void set_expected_type(const Type& type) { expected_type_ = std::move(type); } const absl::optional& expected_type() const { return expected_type_; } @@ -194,10 +203,14 @@ class TypeCheckEnv { absl::string_view type, absl::string_view value) const; absl_nonnull std::shared_ptr descriptor_pool_; + // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; std::string container_; + // Used to resolve fields on message types. + std::shared_ptr proto_type_introspector_; + // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 8aa5177a5..7545aa949 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -84,19 +84,55 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { return absl::OkStatus(); } +absl::Status AddWellKnownContextDeclarationVariables( + const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, + bool use_json_name) { + for (int i = 0; i < descriptor->field_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->field(i); + Type type = MessageTypeField(field).GetType(); + if (type.IsEnum()) { + type = IntType(); + } + absl::string_view name = field->name(); + if (use_json_name) { + name = field->json_name(); + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { + return absl::AlreadyExistsError( + absl::StrCat("variable '", name, + "' declared multiple times (from context declaration: 
'", + descriptor->full_name(), "')")); + } + } + return absl::OkStatus(); +} + absl::Status AddContextDeclarationVariables( const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { - for (int i = 0; i < descriptor->field_count(); i++) { - const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i); - MessageTypeField cel_field(proto_field); - Type field_type = cel_field.GetType(); - if (field_type.IsEnum()) { - field_type = IntType(); + const bool use_json_name = env.proto_type_introspector().use_json_name(); + if (IsWellKnownMessageType(descriptor)) { + return AddWellKnownContextDeclarationVariables(descriptor, env, + use_json_name); + } + CEL_ASSIGN_OR_RETURN(auto fields, + env.proto_type_introspector().ListFieldsForStructType( + descriptor->full_name())); + if (!fields.has_value()) { + return absl::InternalError(absl::StrCat("context declaration '", + descriptor->full_name(), + "' not found, but was expected")); + } + for (const auto& field_entry : *fields) { + Type type = field_entry.field.GetType(); + if (type.IsEnum()) { + type = IntType(); } - if (!env.InsertVariableIfAbsent( - MakeVariableDecl(cel_field.name(), field_type))) { + + absl::string_view name = field_entry.name; + + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( - absl::StrCat("variable '", cel_field.name(), + absl::StrCat("variable '", name, "' declared multiple times (from context declaration: '", descriptor->full_name(), "')")); } @@ -324,6 +360,9 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config)); } + env.proto_type_introspector().set_use_json_name( + options_.use_json_field_names); + for (const ConfigRecord& config : configs) { TypeCheckerSubset* subset = nullptr; if (!config.id.empty()) { diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index e23c26165..f7a3dff97 100644 --- 
a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -42,7 +42,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; - struct ContextDeclsTestCase { std::string expr; TypeSpec expected_type; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 28cbf21e0..8e8047755 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1319,6 +1319,12 @@ absl::StatusOr TypeCheckerImpl::Check( CEL_RETURN_IF_ERROR(rewriter.status()); ast->set_is_checked(true); + if (options_.use_json_field_names) { + ast->mutable_source_info().mutable_extensions().push_back( + cel::ExtensionSpec("json_name", + std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime})); + } return ValidationResult(std::move(ast), std::move(issues)); } diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index a15d2e173..030186c83 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -27,6 +27,7 @@ #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/decl.h" #include "common/type.h" #include "internal/status_macros.h" @@ -496,6 +497,113 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationInt64Value) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, ContextDeclarationWithJsonName) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclaration("cel.cpp.testutil.TestJsonNames"), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); 
+ ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(int32_snake_case_json_name == 1 && + int64CamelCaseJsonName == 2 && + uint32DefaultJsonName == 3u && + // `uint64-custom-json-name` == 4u && + single_string == 'shadows' && + singleString == 'shadowed')cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionStructCreation) { + CheckerOptions options; + options.use_json_field_names = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel(cel.cpp.testutil.TestJsonNames{ + int32_snake_case_json_name: 1, + int64CamelCaseJsonName: 2, + uint32DefaultJsonName: 3u, + `uint64-custom-json-name`: 4u, + single_string: 'shadows', + singleString: 'shadowed' + })cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), + TypeSpec(MessageTypeSpec("cel.cpp.testutil.TestJsonNames"))); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + +TEST(TypeCheckerBuilderTest, JsonFieldNameOptionFieldAccess) { + CheckerOptions options; + options.use_json_field_names = 
true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT( + builder->AddVariable(MakeVariableDecl( + "jsonObj", + cel::MessageType(builder->descriptor_pool()->FindMessageTypeByName( + "cel.cpp.testutil.TestJsonNames")))), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst( + R"cel( + jsonObj.int32_snake_case_json_name == 1 && + jsonObj.int64CamelCaseJsonName == 2 && + jsonObj.uint32DefaultJsonName == 3u && + jsonObj.`uint64-custom-json-name` == 4u && + jsonObj.single_string == 'shadows' && + jsonObj.singleString == 'shadowed' && + jsonObj.`cel.cpp.testutil.int32_snake_case_ext` == 5 && + jsonObj.`cel.cpp.testutil.int64CamelCaseExt` == 6 + )cel")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->GetReturnType(), TypeSpec(PrimitiveType::kBool)); + EXPECT_THAT( + checked_ast->source_info().extensions(), + ElementsAre(cel::ExtensionSpec( + "json_name", std::make_unique(1, 1), + {cel::ExtensionSpec::Component::kRuntime}))); +} + TEST(TypeCheckerBuilderTest, AddLibraryRedeclaredError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, diff --git a/common/type_introspector.cc b/common/type_introspector.cc index 6d5158a2f..26f53685e 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -17,7 +17,9 @@ #include #include #include +#include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" @@ -173,7 +175,8 @@ const WellKnownTypesMap& GetWellKnownTypesMap() { "google.protobuf.Value", WellKnownType{ DynType{}, - 
{MakeBasicStructTypeField("null_value", NullType{}, 1), + {// NullValue enum is an int. Not normally referenced directly. + MakeBasicStructTypeField("null_value", IntType{}, 1), MakeBasicStructTypeField("number_value", DoubleType{}, 2), MakeBasicStructTypeField("string_value", StringType{}, 3), MakeBasicStructTypeField("bool_value", BoolType{}, 4), @@ -228,6 +231,12 @@ TypeIntrospector::FindStructTypeFieldByNameImpl(absl::string_view, return absl::nullopt; } +absl::StatusOr< + absl::optional>> +TypeIntrospector::ListFieldsForStructTypeImpl(absl::string_view) const { + return absl::nullopt; +} + absl::optional FindWellKnownType(absl::string_view name) { const auto& well_known_types = GetWellKnownTypesMap(); if (auto it = well_known_types.find(name); it != well_known_types.end()) { @@ -240,7 +249,7 @@ absl::optional FindWellKnownTypeEnumConstant( absl::string_view type, absl::string_view value) { if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { return TypeIntrospector::EnumConstant{ - NullType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; + IntType{}, "google.protobuf.NullValue", "NULL_VALUE", 0}; } return absl::nullopt; } @@ -254,4 +263,15 @@ absl::optional FindWellKnownTypeFieldByName( return absl::nullopt; } +absl::optional> +ListFieldsForWellKnownType(absl::string_view type) { + const auto& well_known_types = GetWellKnownTypesMap(); + auto it = well_known_types.find(type); + if (it == well_known_types.end()) { + return absl::nullopt; + } + // The fields are not normally gettable. 
+ return std::vector<TypeIntrospector::StructTypeFieldListing>(); // != nullopt +} + } // namespace cel diff --git a/common/type_introspector.h b/common/type_introspector.h index fb6ea09c1..932fb108e 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ #include +#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -40,6 +41,15 @@ class TypeIntrospector { int32_t number; }; + struct StructTypeFieldListing { + // The name used to access the field in source CEL. + // This is assumed owned by the TypeIntrospector or a dependency that + // outlives it. + absl::string_view name; + // The field description. + StructTypeField field; + }; + virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. @@ -61,6 +71,18 @@ class TypeIntrospector { return FindStructTypeFieldByNameImpl(type, name); } + // `ListFieldsForStructType` returns the fields of struct type `type`. + // + // This is used when the struct is declared as a context type. + // + // If the type is not found, returns `absl::nullopt`. + // If the type exists but is not a struct or has no fields, returns an empty + // vector. + absl::StatusOr>> + ListFieldsForStructType(absl::string_view type) const { + return ListFieldsForStructTypeImpl(type); + } + // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in struct type `type`. absl::StatusOr> FindStructTypeFieldByName( @@ -78,6 +100,9 @@ class TypeIntrospector { virtual absl::StatusOr> FindStructTypeFieldByNameImpl(absl::string_view type, absl::string_view name) const; + + virtual absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const; }; // Looks up a well-known type by name.
@@ -91,6 +116,9 @@ absl::optional FindWellKnownTypeEnumConstant( absl::optional FindWellKnownTypeFieldByName( absl::string_view type, absl::string_view name); +absl::optional> +ListFieldsForWellKnownType(absl::string_view type); + // `WellKnownTypeIntrospector` is an implementation of `TypeIntrospector` which // handles well known types that are treated specially by CEL. // @@ -117,6 +145,11 @@ class WellKnownTypeIntrospector : public virtual TypeIntrospector { absl::string_view type, absl::string_view name) const final { return FindWellKnownTypeFieldByName(type, name); } + + absl::StatusOr>> + ListFieldsForStructTypeImpl(absl::string_view type) const final { + return ListFieldsForWellKnownType(type); + } }; } // namespace cel diff --git a/internal/BUILD b/internal/BUILD index 59f68df9b..6bd0f0a46 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -523,6 +523,7 @@ cel_proto_transitive_descriptor_set( deps = [ "//eval/testutil:test_extensions_proto", "//eval/testutil:test_message_proto", + "//testutil:test_json_names_proto", "@com_google_cel_spec//proto/cel/expr:checked_proto", "@com_google_cel_spec//proto/cel/expr:expr_proto", "@com_google_cel_spec//proto/cel/expr:syntax_proto", diff --git a/testutil/BUILD b/testutil/BUILD index 3f1aa4fe8..292696033 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") @@ -86,3 +87,8 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +proto_library( + name = "test_json_names_proto", + srcs = ["test_json_names.proto"], +) diff --git a/testutil/test_json_names.proto b/testutil/test_json_names.proto new file mode 100644 index 000000000..a9551085b --- /dev/null +++ b/testutil/test_json_names.proto @@ -0,0 +1,31 @@ +edition = "2024"; + +package cel.cpp.testutil; + +option features.enforce_naming_style = STYLE_LEGACY; + +// This proto tests json_name options +message TestJsonNames { + int32 int32_snake_case_json_name = 1 + [json_name = "int32_snake_case_json_name"]; + int64 int64_camel_case_json_name = 2 [json_name = "int64CamelCaseJsonName"]; + uint32 uint32_default_json_name = 3; + uint64 uint64_custom_json_name = 4 [json_name = "uint64-custom-json-name"]; + + // Collides with normal field name. + string string_json_name_shadows = 5 [json_name = "single_string"]; + string single_string = 6; + + // protoc should fail on cases like these + // double double_json_shadow_default = 7 [json_name = "doubleJsonDefault"] + // double double_json_default = 8; + // double double_json_swapped_a = 7 [json_name = "double_json_swapped_b"]; + // double double_json_swapped_b = 8 [json_name = "double_json_swapped_a"]; + + extensions 100 to 199; +} + +extend TestJsonNames { + int32 int32_snake_case_ext = 100; + int64 int64CamelCaseExt = 101; +} From 5ef8a8f30dcbf2b891866273f2d74ef53fdaa763 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 7 Apr 2026 10:53:00 -0700 Subject: [PATCH 25/88] Add `Validator` and update Compiler to support. Adds a `Validator` type for applying post type-check validations. Add a timestamp/duration literal validator as in cel-java. Adds a Validator to the cel::Compiler to optionally add post-compilation validations to run after type-checking. 
PiperOrigin-RevId: 895985690 --- compiler/BUILD | 3 + compiler/compiler.h | 5 + compiler/compiler_factory.cc | 18 ++- compiler/compiler_factory_test.cc | 18 +++ validator/BUILD | 89 +++++++++++ validator/timestamp_literal_validator.cc | 134 ++++++++++++++++ validator/timestamp_literal_validator.h | 29 ++++ validator/timestamp_literal_validator_test.cc | 146 +++++++++++++++++ validator/validator.cc | 85 ++++++++++ validator/validator.h | 151 ++++++++++++++++++ validator/validator_test.cc | 85 ++++++++++ 11 files changed, 760 insertions(+), 3 deletions(-) create mode 100644 validator/BUILD create mode 100644 validator/timestamp_literal_validator.cc create mode 100644 validator/timestamp_literal_validator.h create mode 100644 validator/timestamp_literal_validator_test.cc create mode 100644 validator/validator.cc create mode 100644 validator/validator.h create mode 100644 validator/validator_test.cc diff --git a/compiler/BUILD b/compiler/BUILD index 02bbb37dd..44ef4f537 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -27,6 +27,7 @@ cc_library( "//checker:validation_result", "//parser:options", "//parser:parser_interface", + "//validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -48,6 +49,7 @@ cc_library( "//internal:status_macros", "//parser", "//parser:parser_interface", + "//validator", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -78,6 +80,7 @@ cc_test( "//parser:macro", "//parser:parser_interface", "//testutil:baseline_tests", + "//validator:timestamp_literal_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/compiler/compiler.h b/compiler/compiler.h index 8b867cd60..6178cf2dc 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -28,6 +28,7 @@ #include "checker/validation_result.h" #include 
"parser/options.h" #include "parser/parser_interface.h" +#include "validator/validator.h" namespace cel { @@ -109,6 +110,7 @@ class CompilerBuilder { virtual TypeCheckerBuilder& GetCheckerBuilder() = 0; virtual ParserBuilder& GetParserBuilder() = 0; + virtual Validator& GetValidator() = 0; virtual absl::StatusOr> Build() = 0; }; @@ -135,6 +137,9 @@ class Compiler { // Accessor for the underlying parser. virtual const Parser& GetParser() const = 0; + + // Accessor for the underlying validator. + virtual const Validator& GetValidator() const = 0; }; } // namespace cel diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 6530dd816..c83633f68 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -32,6 +32,7 @@ #include "internal/status_macros.h" #include "parser/parser.h" #include "parser/parser_interface.h" +#include "validator/validator.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -41,8 +42,12 @@ namespace { class CompilerImpl : public Compiler { public: CompilerImpl(std::unique_ptr type_checker, - std::unique_ptr parser) - : type_checker_(std::move(type_checker)), parser_(std::move(parser)) {} + std::unique_ptr parser, + // Copy the validator in case builder is reused. 
+ Validator validator) + : type_checker_(std::move(type_checker)), + parser_(std::move(parser)), + validator_(std::move(validator)) {} absl::StatusOr Compile( absl::string_view expression, @@ -54,15 +59,20 @@ class CompilerImpl : public Compiler { type_checker_->Check(std::move(ast))); result.SetSource(std::move(source)); + if (!validator_.validations().empty()) { + validator_.UpdateValidationResult(result); + } return result; } const TypeChecker& GetTypeChecker() const override { return *type_checker_; } const Parser& GetParser() const override { return *parser_; } + const Validator& GetValidator() const override { return validator_; } private: std::unique_ptr type_checker_; std::unique_ptr parser_; + Validator validator_; }; class CompilerBuilderImpl : public CompilerBuilder { @@ -126,17 +136,19 @@ class CompilerBuilderImpl : public CompilerBuilder { TypeCheckerBuilder& GetCheckerBuilder() override { return *type_checker_builder_; } + Validator& GetValidator() override { return validator_; } absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); return std::make_unique(std::move(type_checker), - std::move(parser)); + std::move(parser), validator_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; + Validator validator_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 5df0f4794..cfdc68e26 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -35,6 +35,7 @@ #include "parser/macro.h" #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" +#include "validator/timestamp_literal_validator.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -287,6 +288,23 @@ TEST(CompilerFactoryTest, DisableStandardMacrosWithStdlib) { EXPECT_TRUE(result.IsValid()); } 
+TEST(CompilerFactoryTest, AddValidator) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + builder->GetValidator().AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("timestamp('invalid')")); + EXPECT_FALSE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(result, + compiler->Compile("timestamp('2024-01-01T00:00:00Z')")); + EXPECT_TRUE(result.IsValid()); +} + TEST(CompilerFactoryTest, FailsIfLibraryAddedTwice) { ASSERT_OK_AND_ASSIGN( auto builder, diff --git a/validator/BUILD b/validator/BUILD new file mode 100644 index 000000000..65f7fd6b3 --- /dev/null +++ b/validator/BUILD @@ -0,0 +1,89 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "validator", + srcs = ["validator.cc"], + hdrs = ["validator.h"], + deps = [ + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:navigable_ast", + "//common:source", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "validator_test", + srcs = ["validator_test.cc"], + deps = [ + ":validator", + "//checker:type_check_issue", + "//common:ast", + "//common:expr", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "timestamp_literal_validator_test", + srcs = ["timestamp_literal_validator_test.cc"], + deps = [ + ":timestamp_literal_validator", + ":validator", + "//checker:validation_result", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "timestamp_literal_validator", + srcs = ["timestamp_literal_validator.cc"], + hdrs = ["timestamp_literal_validator.h"], + deps = [ + ":validator", + "//common:constant", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:time", + "//tools:navigable_ast", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + ], +) + +licenses(["notice"]) diff --git a/validator/timestamp_literal_validator.cc 
b/validator/timestamp_literal_validator.cc new file mode 100644 index 000000000..8b9b76ebb --- /dev/null +++ b/validator/timestamp_literal_validator.cc @@ -0,0 +1,134 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/timestamp_literal_validator.h" + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "common/constant.h" +#include "common/navigable_ast.h" +#include "common/standard_definitions.h" +#include "internal/time.h" +#include "tools/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +bool ValidateTimestamps(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kTimestamp) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. 
+ continue; + } + absl::Time ts; + const Constant& constant = child.expr()->const_expr(); + if (constant.has_string_value()) { + absl::string_view timestamp_str = + child.expr()->const_expr().string_value(); + if (!absl::ParseTime(absl::RFC3339_full, timestamp_str, &ts, nullptr)) { + context.ReportErrorAt(child.expr()->id(), "invalid timestamp literal"); + valid = false; + continue; + } + } else if (constant.has_int_value()) { + ts = absl::FromUnixSeconds(constant.int_value()); + } else { + // Checker should have already reported an error. + continue; + } + + if (absl::Status status = internal::ValidateTimestamp(ts); !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid timestamp literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +bool ValidateDurations(ValidationContext& context) { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kCall || + node.expr()->call_expr().function() != StandardFunctions::kDuration) { + continue; + } + if (node.children().size() != 1) { + // Checker should have already reported an error. + continue; + } + const NavigableAstNode& child = *node.children()[0]; + if (child.node_kind() != NodeKind::kConstant) { + // Not a literal, so nothing to do. 
+ continue; + } + const Constant& constant = child.expr()->const_expr(); + if (!constant.has_string_value()) { + continue; + } + absl::Duration duration; + + absl::string_view duration_str = child.expr()->const_expr().string_value(); + if (!absl::ParseDuration(duration_str, &duration)) { + context.ReportErrorAt(child.expr()->id(), "invalid duration literal"); + valid = false; + continue; + } + + if (absl::Status status = internal::ValidateDuration(duration); + !status.ok()) { + context.ReportErrorAt( + child.expr()->id(), + absl::StrCat("invalid duration literal: ", status.message())); + valid = false; + } + } + + return valid; +} + +} // namespace + +const Validation& TimestampLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateTimestamps, "cel.validator.timestamp"); + return *kInstance; +} + +// Returns a validator that checks duration literals. +const Validation& DurationLiteralValidator() { + static const absl::NoDestructor kInstance( + ValidateDurations, "cel.validator.duration"); + return *kInstance; +} + +} // namespace cel diff --git a/validator/timestamp_literal_validator.h b/validator/timestamp_literal_validator.h new file mode 100644 index 000000000..6d2a39318 --- /dev/null +++ b/validator/timestamp_literal_validator.h @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ + +#include "validator/validator.h" +namespace cel { + +// Returns a `Validation` that checks timestamp literals are valid for CEL. +const Validation& TimestampLiteralValidator(); + +// Returns a `Validation` that checks duration literals are valid for CEL. +const Validation& DurationLiteralValidator(); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_TIMESTAMP_LITERAL_VALIDATOR_H_ diff --git a/validator/timestamp_literal_validator_test.cc b/validator/timestamp_literal_validator_test.cc new file mode 100644 index 000000000..136f7d645 --- /dev/null +++ b/validator/timestamp_literal_validator_test.cc @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "validator/timestamp_literal_validator.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + auto builder = + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()).value(); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + return builder->Build(); +} + +class TimestampLiteralValidatorTest : public ::testing::Test { + protected: + TimestampLiteralValidatorTest() { + validator_.AddValidation(TimestampLiteralValidator()); + } + + std::unique_ptr compiler_; + Validator validator_; +}; + +TEST(TimestampLiteralValidatorTest, FormatsIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile("timestamp('invalid')")); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_EQ(result.FormatError(), + R"(ERROR: :1:11: invalid timestamp literal + | timestamp('invalid') + | ..........^)"); +} + +TEST(TimestampLiteralValidatorTest, AccumulatesIssues) { + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + constexpr absl::string_view kExpression = R"cel( + [ timestamp('invalid'), + timestamp('9999-12-31T23:59:59Z'), + timestamp('10000-01-01T00:00:00Z') + ].all(t, + t - timestamp(0) < duration('10000s') && + t - timestamp(0) > duration("invalid") + ))cel"; + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + 
ASSERT_OK_AND_ASSIGN(cel::ValidationResult result, + compiler->Compile(kExpression)); + + validator.UpdateValidationResult(result); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + AllOf(HasSubstr("2:17: invalid timestamp literal"), + HasSubstr("4:17: invalid timestamp literal"), + HasSubstr("7:35: invalid duration literal"))); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using TimestampLiteralValidatorParameterizedTest = + testing::TestWithParam; + +TEST_P(TimestampLiteralValidatorParameterizedTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(TimestampLiteralValidator()); + validator.AddValidation(DurationLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + TimestampLiteralValidatorParameterizedTest, + TimestampLiteralValidatorParameterizedTest, + ::testing::Values( + TestCase{"timestamp('2023-01-01T00:00:00Z')", true}, + TestCase{"timestamp('9999-12-31T23:59:59Z')", true}, + TestCase{"timestamp('invalid')", false, "invalid timestamp literal"}, + TestCase{"timestamp('10000-01-01T00:00:00Z')", false, + "invalid timestamp literal"}, + TestCase{"timestamp(0)", true}, + TestCase{"timestamp(-62135596801)", false, + "invalid timestamp literal: Timestamp \"0-12-31T23:59:59Z\" " + "below minimum allowed timestamp \"1-01-01T00:00:00Z\""}, + TestCase{"timestamp(253402300800)", false, + "invalid timestamp literal: Timestamp " + "\"10000-01-01T00:00:00Z\" above maximum allowed timestamp " + "\"9999-12-31T23:59:59.999999999Z\""}, + TestCase{"duration('1s')", true}, + TestCase{"duration('invalid')", false, 
"invalid duration literal"}, + TestCase{"duration('-1000000000000s')", false, + "below minimum allowed duration"}, + TestCase{"duration('1000000000000s')", false, + "above maximum allowed duration"})); + +} // namespace +} // namespace cel diff --git a/validator/validator.cc b/validator/validator.cc new file mode 100644 index 000000000..e000c71e8 --- /dev/null +++ b/validator/validator.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/validator.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/source.h" + +namespace cel { + +void Validator::AddValidation(Validation validation) { + ABSL_DCHECK(validation); + if (!validation) return; + validations_.push_back(std::move(validation)); +} + +Validator::ValidationOutput Validator::Validate(const Ast& ast) const { + ValidationOutput result; + ValidationContext context(ast); + for (const auto& validation : validations_) { + if (!validation(context)) { + result.valid = false; + } + } + result.issues = context.ReleaseIssues(); + return result; +} + +void Validator::UpdateValidationResult(ValidationResult& in) const { + if (!in.IsValid() || in.GetAst() == nullptr) { + // If the result is already decided invalid, just return it. 
+ return; + } + + auto result = Validate(*in.GetAst()); + if (!result.valid) { + in.ReleaseAst().IgnoreError(); + } + for (auto& issue : result.issues) { + in.AddIssue(std::move(issue)); + } +} + +void ValidationContext::ReportWarningAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportErrorAt(int64_t id, absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + ast_.ComputeSourceLocation(id), + std::string(message))); +} + +void ValidationContext::ReportWarning(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kWarning, + SourceLocation{}, std::string(message))); +} + +void ValidationContext::ReportError(absl::string_view message) { + issues_.push_back(TypeCheckIssue(TypeCheckIssue::Severity::kError, + SourceLocation{}, std::string(message))); +} + +} // namespace cel diff --git a/validator/validator.h b/validator/validator.h new file mode 100644 index 000000000..a278bd44f --- /dev/null +++ b/validator/validator.h @@ -0,0 +1,151 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/navigable_ast.h" +namespace cel { + +// Context for a validation pass. +// +// Assumed to be scoped to a Validator::Validate() call. Instances must not +// outlive the `ast` passed to the constructor. +class ValidationContext { + public: + explicit ValidationContext(const Ast& ast ABSL_ATTRIBUTE_LIFETIME_BOUND) + : ast_(ast) {} + + const Ast& ast() const { return ast_; } + const NavigableAst& navigable_ast() const { + if (!navigable_ast_) { + navigable_ast_ = NavigableAst::Build(ast_.root_expr()); + } + return navigable_ast_; + } + + void ReportWarningAt(int64_t id, absl::string_view message); + void ReportErrorAt(int64_t id, absl::string_view message); + void ReportWarning(absl::string_view message); + void ReportError(absl::string_view message); + + std::vector ReleaseIssues() { + auto out = std::move(issues_); + issues_.clear(); + return out; + } + + private: + const Ast& ast_; + mutable NavigableAst navigable_ast_; + std::vector issues_; +}; + +// A single validation to apply to an AST. +// +// May be empty if default constructed or moved from. +// Use operator bool() to check if the validation is empty. +class Validation { + public: + // Tests the AST and reports any issues to the context. + // + // Returns false if the AST is invalid. + // + // The same instance is used across Validate() so must be thread safe + // (typically stateless).
+ using ImplFunction = + absl::AnyInvocable; + + Validation() = default; + explicit Validation(ImplFunction impl); + Validation(ImplFunction impl, absl::string_view id); + + const ImplFunction& impl() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl; + } + + absl::string_view id() const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->id; + } + + bool operator()(ValidationContext& context) const { + ABSL_DCHECK(rep_ != nullptr); + return rep_->impl(context); + } + + explicit operator bool() const { return rep_ != nullptr; } + + private: + struct Rep { + ImplFunction impl; + // Optional id if supported in environment config. + std::string id; + }; + + std::shared_ptr rep_; +}; + +// A validator checks a set of semantic rules for a given AST. +class Validator { + public: + Validator() = default; + + void AddValidation(Validation validation); + absl::Span validations() const { return validations_; } + + struct ValidationOutput { + bool valid = true; + std::vector issues; + }; + + // Validates the given AST by applying all of the validations. + ValidationOutput Validate(const Ast& ast) const; + + // Validates the given AST, updating the validation result in place. + // + // Used to apply validators to the output of the type checker. + void UpdateValidationResult(ValidationResult& in) const; + + private: + std::vector validations_; +}; + +// Implementation details. 
+inline Validation::Validation(ImplFunction impl) + : rep_(std::make_shared( + Validation::Rep{std::move(impl)})) {} + +inline Validation::Validation(ImplFunction impl, absl::string_view id) + : rep_(std::make_shared( + Validation::Rep{std::move(impl), std::string(id)})) {} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_VALIDATOR_H_ diff --git a/validator/validator_test.cc b/validator/validator_test.cc new file mode 100644 index 000000000..744475ec1 --- /dev/null +++ b/validator/validator_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "validator/validator.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Property; + +TEST(ValidatorTest, AddValidationAndValidate) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportError("error 1"); + return false; + })); + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportWarning("warning 1"); + return true; + })); + + Ast ast; + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + ElementsAre(Property(&TypeCheckIssue::message, Eq("error 1")), + Property(&TypeCheckIssue::message, Eq("warning 1")))); + EXPECT_EQ(output.issues[0].severity(), TypeCheckIssue::Severity::kError); + EXPECT_EQ(output.issues[1].severity(), TypeCheckIssue::Severity::kWarning); +} + +TEST(ValidatorTest, ReportAt) { + Validator validator; + validator.AddValidation(Validation([](ValidationContext& context) { + context.ReportErrorAt(1, "error at 1"); + context.ReportWarningAt(2, "warning at 2"); + return false; + })); + + Expr expr; + expr.set_id(1); + SourceInfo source_info; + source_info.mutable_positions()[1] = 10; + source_info.mutable_positions()[2] = 20; + source_info.set_line_offsets({15, 25}); + + Ast ast(std::move(expr), std::move(source_info)); + auto output = validator.Validate(ast); + + EXPECT_FALSE(output.valid); + ASSERT_EQ(output.issues.size(), 2); + + EXPECT_EQ(output.issues[0].location().line, 1); + EXPECT_EQ(output.issues[0].location().column, 10); + + EXPECT_EQ(output.issues[1].location().line, 2); + EXPECT_EQ(output.issues[1].location().column, 5); +} + +} // namespace +} // namespace cel From ea0deec704ddafab3b4bdad6c50d1906cdbcd7df Mon Sep 17 00:00:00 2001 
From: Jonathan Tatum Date: Tue, 7 Apr 2026 11:16:40 -0700 Subject: [PATCH 26/88] Add AST depth validator. PiperOrigin-RevId: 895998515 --- validator/BUILD | 26 +++++++++ validator/ast_depth_validator.cc | 34 +++++++++++ validator/ast_depth_validator.h | 27 +++++++++ validator/ast_depth_validator_test.cc | 81 +++++++++++++++++++++++++++ 4 files changed, 168 insertions(+) create mode 100644 validator/ast_depth_validator.cc create mode 100644 validator/ast_depth_validator.h create mode 100644 validator/ast_depth_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index 65f7fd6b3..dec3d8616 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -86,4 +86,30 @@ cc_library( ], ) +cc_library( + name = "ast_depth_validator", + srcs = ["ast_depth_validator.cc"], + hdrs = ["ast_depth_validator.h"], + deps = [ + ":validator", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ast_depth_validator_test", + srcs = ["ast_depth_validator_test.cc"], + deps = [ + ":ast_depth_validator", + ":validator", + "//checker:type_check_issue", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/log:absl_check", + ], +) + licenses(["notice"]) diff --git a/validator/ast_depth_validator.cc b/validator/ast_depth_validator.cc new file mode 100644 index 000000000..0f6b8d93d --- /dev/null +++ b/validator/ast_depth_validator.cc @@ -0,0 +1,34 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include "absl/strings/str_cat.h" +#include "validator/validator.h" + +namespace cel { + +Validation AstDepthValidator(int max_depth) { + return Validation([max_depth](ValidationContext& context) { + int height = context.navigable_ast().Root().height(); + if (height > max_depth) { + context.ReportError(absl::StrCat("AST depth ", height, + " exceeds maximum of ", max_depth)); + return false; + } + return true; + }); +} + +} // namespace cel diff --git a/validator/ast_depth_validator.h b/validator/ast_depth_validator.h new file mode 100644 index 000000000..a640af12e --- /dev/null +++ b/validator/ast_depth_validator.h @@ -0,0 +1,27 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks the AST depth is less than or equal to +// max_depth. 
+Validation AstDepthValidator(int max_depth); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_AST_DEPTH_VALIDATOR_H_ diff --git a/validator/ast_depth_validator_test.cc b/validator/ast_depth_validator_test.cc new file mode 100644 index 000000000..eda59b40d --- /dev/null +++ b/validator/ast_depth_validator_test.cc @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/ast_depth_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "checker/type_check_issue.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +std::unique_ptr CreateCompiler() { + auto builder = NewCompilerBuilder(internal::GetSharedTestingDescriptorPool()); + ABSL_CHECK_OK(builder); + ABSL_CHECK_OK((*builder)->AddLibrary(StandardCompilerLibrary())); + auto compiler = (*builder)->Build(); + ABSL_CHECK_OK(compiler); + return *std::move(compiler); +} + +TEST(AstDepthValidatorTest, Basic) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("1 + 2 + 3")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + 
validator2.AddValidation(AstDepthValidator(2)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 3 exceeds maximum of 2")))); +} + +TEST(AstDepthValidatorTest, Nested) { + auto compiler = CreateCompiler(); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile("1 + (2 + (3 + (4 + 5)))")); + + Validator validator; + validator.AddValidation(AstDepthValidator(10)); + auto output = validator.Validate(*result.GetAst()); + EXPECT_TRUE(output.valid); + + Validator validator2; + validator2.AddValidation(AstDepthValidator(4)); + output = validator2.Validate(*result.GetAst()); + EXPECT_FALSE(output.valid); + EXPECT_THAT(output.issues, + testing::Contains(testing::Property( + &TypeCheckIssue::message, + testing::Eq("AST depth 5 exceeds maximum of 4")))); +} + +} // namespace +} // namespace cel From 1b2a84162ac9f891a22b2b4d3e1018d190184b45 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 7 Apr 2026 14:51:06 -0700 Subject: [PATCH 27/88] Expose TypeSpec formatter in common/ast/metadata.h PiperOrigin-RevId: 896098407 --- common/ast/BUILD | 3 + common/ast/metadata.cc | 117 ++++++++++++++++++++++++++++++++ common/ast/metadata.h | 3 + testutil/baseline_tests.cc | 71 +------------------ testutil/baseline_tests_test.cc | 2 +- 5 files changed, 125 insertions(+), 71 deletions(-) diff --git a/common/ast/BUILD b/common/ast/BUILD index 410d38c65..17456566b 100644 --- a/common/ast/BUILD +++ b/common/ast/BUILD @@ -98,10 +98,13 @@ cc_library( deps = [ "//common:constant", "//common:expr", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", 
"@com_google_absl//absl/types:variant", ], diff --git a/common/ast/metadata.cc b/common/ast/metadata.cc index eecb0dbb3..38f7ef610 100644 --- a/common/ast/metadata.cc +++ b/common/ast/metadata.cc @@ -14,11 +14,18 @@ #include "common/ast/metadata.h" +#include #include +#include +#include +#include #include #include "absl/base/no_destructor.h" +#include "absl/base/nullability.h" #include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" #include "absl/types/variant.h" namespace cel { @@ -30,6 +37,96 @@ const TypeSpec& DefaultTypeSpec() { return *type; } +std::string FormatPrimitive(PrimitiveType t) { + switch (t) { + case PrimitiveType::kBool: + return "bool"; + case PrimitiveType::kInt64: + return "int"; + case PrimitiveType::kUint64: + return "uint"; + case PrimitiveType::kDouble: + return "double"; + case PrimitiveType::kString: + return "string"; + case PrimitiveType::kBytes: + return "bytes"; + default: + return "*unspecified primitive*"; + } +} + +std::string FormatWellKnown(WellKnownTypeSpec t) { + switch (t) { + case WellKnownTypeSpec::kAny: + return "google.protobuf.Any"; + case WellKnownTypeSpec::kDuration: + return "google.protobuf.Duration"; + case WellKnownTypeSpec::kTimestamp: + return "google.protobuf.Timestamp"; + default: + return "*unspecified well known*"; + } +} + +using FormatIns = std::variant; +using FormatStack = std::vector; + +void HandleFormatTypeSpec(const TypeSpec& t, FormatStack& stack, + std::string* out) { + if (t.has_dyn()) { + absl::StrAppend(out, "dyn"); + } else if (t.has_null()) { + absl::StrAppend(out, "null"); + } else if (t.has_primitive()) { + absl::StrAppend(out, FormatPrimitive(t.primitive())); + } else if (t.has_wrapper()) { + absl::StrAppend(out, "wrapper(", FormatPrimitive(t.wrapper()), ")"); + } else if (t.has_well_known()) { + absl::StrAppend(out, FormatWellKnown(t.well_known())); + return; + } else if (t.has_abstract_type()) { + const auto& abs_type = t.abstract_type(); + 
if (abs_type.parameter_types().empty()) { + absl::StrAppend(out, abs_type.name()); + return; + } + absl::StrAppend(out, abs_type.name(), "("); + stack.push_back(")"); + for (size_t i = abs_type.parameter_types().size(); i > 0; --i) { + stack.push_back(&abs_type.parameter_types()[i - 1]); + if (i > 1) { + stack.push_back(", "); + } + } + + } else if (t.has_type()) { + if (t.type() == TypeSpec()) { + absl::StrAppend(out, "type"); + return; + } + absl::StrAppend(out, "type("); + stack.push_back(")"); + stack.push_back(&t.type()); + } else if (t.has_message_type()) { + absl::StrAppend(out, t.message_type().type()); + } else if (t.has_type_param()) { + absl::StrAppend(out, t.type_param().type()); + } else if (t.has_list_type()) { + absl::StrAppend(out, "list("); + stack.push_back(")"); + stack.push_back(&t.list_type().elem_type()); + } else if (t.has_map_type()) { + absl::StrAppend(out, "map("); + stack.push_back(")"); + stack.push_back(&t.map_type().value_type()); + stack.push_back(", "); + stack.push_back(&t.map_type().key_type()); + } else { + absl::StrAppend(out, "*error*"); + } +} + TypeSpecKind CopyImpl(const TypeSpecKind& other) { return absl::visit( absl::Overload( @@ -142,4 +239,24 @@ FunctionTypeSpec& FunctionTypeSpec::operator=(const FunctionTypeSpec& other) { return *this; } +std::string FormatTypeSpec(const TypeSpec& t) { + // Use a stack to avoid recursion. + // Probably overly defensive, but fuzzers will often notice the recursion + // and try to trigger it. 
+ std::string out; + FormatStack seq; + seq.push_back(&t); + while (!seq.empty()) { + FormatIns ins = std::move(seq.back()); + seq.pop_back(); + if (std::holds_alternative(ins)) { + absl::StrAppend(&out, std::get(ins)); + continue; + } + ABSL_DCHECK(std::holds_alternative(ins)); + HandleFormatTypeSpec(*std::get(ins), seq, &out); + } + return out; +} + } // namespace cel diff --git a/common/ast/metadata.h b/common/ast/metadata.h index a82e999f8..197790ff3 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -740,6 +740,9 @@ class TypeSpec { TypeSpecKind type_kind_; }; +// Returns a string representation of the given TypeSpec. +std::string FormatTypeSpec(const TypeSpec& t); + // Describes a resolved reference to a declaration. class Reference { public: diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index 4e56ad485..8ce43e63d 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -28,75 +28,6 @@ namespace cel::test { namespace { -std::string FormatPrimitive(PrimitiveType t) { - switch (t) { - case PrimitiveType::kBool: - return "bool"; - case PrimitiveType::kInt64: - return "int"; - case PrimitiveType::kUint64: - return "uint"; - case PrimitiveType::kDouble: - return "double"; - case PrimitiveType::kString: - return "string"; - case PrimitiveType::kBytes: - return "bytes"; - default: - return ""; - } -} - -std::string FormatType(const TypeSpec& t) { - if (t.has_dyn()) { - return "dyn"; - } else if (t.has_null()) { - return "null"; - } else if (t.has_primitive()) { - return FormatPrimitive(t.primitive()); - } else if (t.has_wrapper()) { - return absl::StrCat("wrapper(", FormatPrimitive(t.wrapper()), ")"); - } else if (t.has_well_known()) { - switch (t.well_known()) { - case WellKnownTypeSpec::kAny: - return "google.protobuf.Any"; - case WellKnownTypeSpec::kDuration: - return "google.protobuf.Duration"; - case WellKnownTypeSpec::kTimestamp: - return "google.protobuf.Timestamp"; - default: - return ""; - } - } else 
if (t.has_abstract_type()) { - const auto& abs_type = t.abstract_type(); - std::string s = abs_type.name(); - if (!abs_type.parameter_types().empty()) { - absl::StrAppend(&s, "(", - absl::StrJoin(abs_type.parameter_types(), ",", - [](std::string* out, const auto& t) { - absl::StrAppend(out, FormatType(t)); - }), - ")"); - } - return s; - } else if (t.has_type()) { - if (t.type() == TypeSpec()) { - return "type"; - } - return absl::StrCat("type(", FormatType(t.type()), ")"); - } else if (t.has_message_type()) { - return t.message_type().type(); - } else if (t.has_type_param()) { - return t.type_param().type(); - } else if (t.has_list_type()) { - return absl::StrCat("list(", FormatType(t.list_type().elem_type()), ")"); - } else if (t.has_map_type()) { - return absl::StrCat("map(", FormatType(t.map_type().key_type()), ", ", - FormatType(t.map_type().value_type()), ")"); - } - return ""; -} - std::string FormatReference(const cel::Reference& r) { if (r.overload_id().empty()) { return r.name(); @@ -113,7 +44,7 @@ class TypeAdorner : public ExpressionAdorner { auto t = ast_.type_map().find(e.id()); if (t != ast_.type_map().end()) { - absl::StrAppend(&s, "~", FormatType(t->second)); + absl::StrAppend(&s, "~", FormatTypeSpec(t->second)); } if (const auto r = ast_.reference_map().find(e.id()); r != ast_.reference_map().end()) { diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index 33050583f..f4e89706c 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -184,7 +184,7 @@ INSTANTIATE_TEST_SUITE_P( "x~google.protobuf.Timestamp"}, TestCase{TypeSpec(DynTypeSpec()), "x~dyn"}, TestCase{TypeSpec(NullTypeSpec()), "x~null"}, - TestCase{TypeSpec(UnsetTypeSpec()), "x~"}, + TestCase{TypeSpec(UnsetTypeSpec()), "x~*error*"}, TestCase{TypeSpec(MessageTypeSpec("com.example.Type")), "x~com.example.Type"}, TestCase{TypeSpec(AbstractType("optional_type", From 7022451780867bd9c173956f5793b5dfb600928e Mon Sep 17 00:00:00 2001 From: 
Dmitri Plotnikov Date: Tue, 7 Apr 2026 16:58:41 -0700 Subject: [PATCH 28/88] Expose EnvRuntime::RegisterExtensionFunctions API PiperOrigin-RevId: 896158131 --- env/BUILD | 7 +++++++ env/config.cc | 8 +++++++- env/env_runtime.cc | 12 ++++++++++++ env/env_runtime.h | 12 ++++++++++++ env/env_runtime_test.cc | 39 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 77 insertions(+), 1 deletion(-) diff --git a/env/BUILD b/env/BUILD index 8d477cc1f..55297b190 100644 --- a/env/BUILD +++ b/env/BUILD @@ -81,7 +81,10 @@ cc_library( "//runtime:runtime_builder_factory", "//runtime:runtime_options", "//runtime:standard_functions", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) @@ -236,10 +239,14 @@ cc_test( "//common:source", "//common:value", "//compiler", + "//extensions:math_ext", + "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", "//runtime:activation", + "//runtime:runtime_builder", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", diff --git a/env/config.cc b/env/config.cc index ccb4de34c..202a607bf 100644 --- a/env/config.cc +++ b/env/config.cc @@ -58,9 +58,15 @@ absl::Status Config::AddExtensionConfig(std::string name, int version) { if (extension_config.version == version) { return absl::OkStatus(); } + std::string version_str; + if (version == ExtensionConfig::kLatest) { + version_str = "'latest'"; + } else { + version_str = absl::StrCat(version); + } return absl::AlreadyExistsError(absl::StrCat( "Extension '", name, "' version ", extension_config.version, - " is already included. Cannot also include version ", version)); + " is already included. 
Cannot also include version ", version_str)); } } extension_configs_.push_back( diff --git a/env/env_runtime.cc b/env/env_runtime.cc index 09bbcde04..33e0747cc 100644 --- a/env/env_runtime.cc +++ b/env/env_runtime.cc @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "internal/status_macros.h" #include "runtime/runtime.h" @@ -29,6 +32,15 @@ namespace cel { +void EnvRuntime::RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback) { + extension_registry_.AddFunctionRegistration( + name, alias, version, std::move(function_registration_callback)); +} + absl::StatusOr EnvRuntime::CreateRuntimeBuilder() { const std::vector& extension_configs = config_.GetExtensionConfigs(); diff --git a/env/env_runtime.h b/env/env_runtime.h index ff62ec1d4..63473c295 100644 --- a/env/env_runtime.h +++ b/env/env_runtime.h @@ -18,7 +18,10 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "env/config.h" #include "env/internal/runtime_ext_registry.h" #include "runtime/runtime.h" @@ -41,6 +44,15 @@ namespace cel { // compilation. This ensures consistency between compilation and runtime. class EnvRuntime { public: + // Registers a function registration callback for an extension. The callback + // is invoked when a runtime is created, if the corresponding functions are + // enabled in the runtime config. 
+ void RegisterExtensionFunctions( + absl::string_view name, absl::string_view alias, int version, + absl::AnyInvocable + function_registration_callback); + void SetDescriptorPool( std::shared_ptr descriptor_pool) { descriptor_pool_ = std::move(descriptor_pool); diff --git a/env/env_runtime_test.cc b/env/env_runtime_test.cc index 1c4205224..47892772c 100644 --- a/env/env_runtime_test.cc +++ b/env/env_runtime_test.cc @@ -31,10 +31,13 @@ #include "env/env_std_extensions.h" #include "env/env_yaml.h" #include "env/runtime_std_extensions.h" +#include "extensions/math_ext.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "runtime/activation.h" #include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" #include "google/protobuf/arena.h" namespace cel { @@ -156,5 +159,41 @@ std::vector GetEnvRuntimeTestCases() { INSTANTIATE_TEST_SUITE_P(EnvRuntimeTest, EnvRuntimeTest, ValuesIn(GetEnvRuntimeTestCases())); +TEST(EnvRuntimeTest, RegisterExtensionFunctions) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + Config config; + ASSERT_THAT(config.AddExtensionConfig("math", 2), IsOk()); + + Env env; + env.SetDescriptorPool(descriptor_pool); + RegisterStandardExtensions(env); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("math.sqrt(4) == 2.0")); + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, result.ReleaseAst()); + + EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + env_runtime.RegisterExtensionFunctions( + "cel.lib.math", "math", 2, + [](cel::RuntimeBuilder& runtime_builder, + const cel::RuntimeOptions& opts) -> absl::Status { + return cel::extensions::RegisterMathExtensionFunctions( + runtime_builder.function_registry(), opts, 2); + }); + 
env_runtime.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime->CreateProgram(std::move(ast))); + ASSERT_NE(program, nullptr); + + google::protobuf::Arena arena; + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_TRUE(value.GetBool()); +} } // namespace } // namespace cel From e9ba71a59cb8c7c35315c330c79119dc61d68733 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 13 Apr 2026 09:05:30 -0700 Subject: [PATCH 29/88] Use proto2::ConstMapIterator and stop const_cast'ing. PiperOrigin-RevId: 899029119 --- common/values/parsed_json_map_value.cc | 4 +- common/values/parsed_map_field_value.cc | 26 ++++++------ .../protobuf/internal/map_reflection.cc | 42 ++++++++----------- extensions/protobuf/internal/map_reflection.h | 12 +++--- internal/json.cc | 10 ++--- internal/message_equality.cc | 8 ++-- internal/well_known_types.cc | 12 +++--- internal/well_known_types.h | 9 ++-- 8 files changed, 59 insertions(+), 64 deletions(-) diff --git a/common/values/parsed_json_map_value.cc b/common/values/parsed_json_map_value.cc index 6072a0b21..ec8c91a4f 100644 --- a/common/values/parsed_json_map_value.cc +++ b/common/values/parsed_json_map_value.cc @@ -408,8 +408,8 @@ class ParsedJsonMapValueIterator final : public ValueIterator { private: const google::protobuf::Message* absl_nonnull const message_; const well_known_types::StructReflection reflection_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; std::string scratch_; }; diff --git a/common/values/parsed_map_field_value.cc b/common/values/parsed_map_field_value.cc index 737593cca..47b737f82 100644 --- a/common/values/parsed_map_field_value.cc +++ b/common/values/parsed_map_field_value.cc @@ -415,10 +415,10 @@ absl::Status 
ParsedMapFieldValue::ListKeys( field_->message_type()->map_key())); auto builder = NewListValueBuilder(arena); builder->Reserve(Size()); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, *message_, *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); for (; begin != end; ++begin) { Value scratch; (*key_accessor)(begin.GetKey(), message_, arena, &scratch); @@ -446,10 +446,10 @@ absl::Status ParsedMapFieldValue::ForEach( CEL_ASSIGN_OR_RETURN( auto value_accessor, common_internal::MapFieldValueAccessorFor(value_field)); - auto begin = extensions::protobuf_internal::MapBegin(*reflection, *message_, - *field_); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, *message_, *field_); + auto begin = extensions::protobuf_internal::ConstMapBegin( + *reflection, *message_, *field_); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, *message_, *field_); Value key_scratch; Value value_scratch; for (; begin != end; ++begin) { @@ -479,10 +479,10 @@ class ParsedMapFieldValueIterator final : public ValueIterator { value_field_(field->message_type()->map_value()), key_accessor_(key_accessor), value_accessor_(value_accessor), - begin_(extensions::protobuf_internal::MapBegin( + begin_(extensions::protobuf_internal::ConstMapBegin( *message_->GetReflection(), *message_, *field)), - end_(extensions::protobuf_internal::MapEnd(*message_->GetReflection(), - *message_, *field)) {} + end_(extensions::protobuf_internal::ConstMapEnd( + *message_->GetReflection(), *message_, *field)) {} bool HasNext() override { return begin_ != end_; } @@ -545,8 +545,8 @@ class ParsedMapFieldValueIterator final : public ValueIterator { const google::protobuf::FieldDescriptor* absl_nonnull const 
value_field_; const absl_nonnull common_internal::MapFieldKeyAccessor key_accessor_; const absl_nonnull common_internal::MapFieldValueAccessor value_accessor_; - google::protobuf::MapIterator begin_; - const google::protobuf::MapIterator end_; + google::protobuf::ConstMapIterator begin_; + const google::protobuf::ConstMapIterator end_; }; } // namespace diff --git a/extensions/protobuf/internal/map_reflection.cc b/extensions/protobuf/internal/map_reflection.cc index 22a6dc23c..605e4437d 100644 --- a/extensions/protobuf/internal/map_reflection.cc +++ b/extensions/protobuf/internal/map_reflection.cc @@ -42,22 +42,16 @@ class CelMapReflectionFriend final { return reflection.MapSize(message, &field); } - static google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapBegin( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapBegin( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapBegin(&message, &field); } - static google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return reflection.MapEnd( - const_cast< // NOLINT(google3-runtime-proto-const-cast) - google::protobuf::Message*>(&message), - &field); + static google::protobuf::ConstMapIterator ConstMapEnd( + const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return reflection.ConstMapEnd(&message, &field); } static bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, @@ -104,18 +98,18 @@ int 
MapSize(const google::protobuf::Reflection& reflection, field); } -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapBegin(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapBegin(reflection, + message, field); } -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field) { - return google::protobuf::expr::CelMapReflectionFriend::MapEnd(reflection, message, - field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field) { + return google::protobuf::expr::CelMapReflectionFriend::ConstMapEnd(reflection, message, + field); } bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, diff --git a/extensions/protobuf/internal/map_reflection.h b/extensions/protobuf/internal/map_reflection.h index 6e696bbe3..681d7693d 100644 --- a/extensions/protobuf/internal/map_reflection.h +++ b/extensions/protobuf/internal/map_reflection.h @@ -42,13 +42,13 @@ int MapSize(const google::protobuf::Reflection& reflection, const google::protobuf::Message& message, const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapBegin(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapBegin(const google::protobuf::Reflection& reflection, + const 
google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); -google::protobuf::MapIterator MapEnd(const google::protobuf::Reflection& reflection, - const google::protobuf::Message& message, - const google::protobuf::FieldDescriptor& field); +google::protobuf::ConstMapIterator ConstMapEnd(const google::protobuf::Reflection& reflection, + const google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field); bool InsertOrLookupMapValue(const google::protobuf::Reflection& reflection, google::protobuf::Message* message, diff --git a/internal/json.cc b/internal/json.cc index 200d18bfb..630ceb267 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -803,10 +803,10 @@ class MessageToJsonState { const auto* value_descriptor = field->message_type()->map_value(); CEL_ASSIGN_OR_RETURN(const auto value_to_value, GetMapFieldValueToValue(value_descriptor)); - auto begin = - extensions::protobuf_internal::MapBegin(*reflection, message, *field); - const auto end = - extensions::protobuf_internal::MapEnd(*reflection, message, *field); + auto begin = extensions::protobuf_internal::ConstMapBegin(*reflection, + message, *field); + const auto end = extensions::protobuf_internal::ConstMapEnd( + *reflection, message, *field); for (; begin != end; ++begin) { auto key = (*key_to_string)(begin.GetKey()); CEL_RETURN_IF_ERROR((this->*value_to_value)( @@ -1381,7 +1381,7 @@ class JsonMapIterator final { using Generated = typename google::protobuf::Map::const_iterator; - using Dynamic = google::protobuf::MapIterator; + using Dynamic = google::protobuf::ConstMapIterator; using Value = std::pair; diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 628432d66..945cca8df 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -50,9 +50,9 @@ namespace cel::internal { namespace { +using ::cel::extensions::protobuf_internal::ConstMapBegin; +using ::cel::extensions::protobuf_internal::ConstMapEnd; using 
::cel::extensions::protobuf_internal::LookupMapValue; -using ::cel::extensions::protobuf_internal::MapBegin; -using ::cel::extensions::protobuf_internal::MapEnd; using ::cel::extensions::protobuf_internal::MapSize; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorPool; @@ -904,8 +904,8 @@ class MessageEqualsState final { MapSize(*rhs_reflection, rhs, *rhs_field)) { return false; } - auto lhs_begin = MapBegin(*lhs_reflection, lhs, *lhs_field); - const auto lhs_end = MapEnd(*lhs_reflection, lhs, *lhs_field); + auto lhs_begin = ConstMapBegin(*lhs_reflection, lhs, *lhs_field); + const auto lhs_end = ConstMapEnd(*lhs_reflection, lhs, *lhs_field); Unique lhs_unpacked; EquatableValue lhs_value; Unique rhs_unpacked; diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index c736be69f..dee029534 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -1643,20 +1643,20 @@ int StructReflection::FieldsSize(const google::protobuf::Message& message) const message, *fields_field_); } -google::protobuf::MapIterator StructReflection::BeginFields( +google::protobuf::ConstMapIterator StructReflection::BeginFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapBegin(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapBegin( + *message.GetReflection(), message, *fields_field_); } -google::protobuf::MapIterator StructReflection::EndFields( +google::protobuf::ConstMapIterator StructReflection::EndFields( const google::protobuf::Message& message) const { ABSL_DCHECK(IsInitialized()); ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); - return cel::extensions::protobuf_internal::MapEnd(*message.GetReflection(), - message, *fields_field_); + return cel::extensions::protobuf_internal::ConstMapEnd( + *message.GetReflection(), message, 
*fields_field_); } bool StructReflection::ContainsField(const google::protobuf::Message& message, diff --git a/internal/well_known_types.h b/internal/well_known_types.h index dce88a420..f63e5e76b 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -698,8 +698,9 @@ absl::StatusOr GetAnyReflection( const google::protobuf::Descriptor* absl_nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); -AnyReflection GetAnyReflectionOrDie(const google::protobuf::Descriptor* absl_nonnull - descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); +AnyReflection GetAnyReflectionOrDie( + const google::protobuf::Descriptor* absl_nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); class DurationReflection final { public: @@ -1193,10 +1194,10 @@ class StructReflection final { int FieldsSize(const google::protobuf::Message& message) const; - google::protobuf::MapIterator BeginFields( + google::protobuf::ConstMapIterator BeginFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; - google::protobuf::MapIterator EndFields( + google::protobuf::ConstMapIterator EndFields( const google::protobuf::Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND) const; bool ContainsField(const google::protobuf::Message& message, From 15009ff50e3eb011b8e0ba3b46d55c559f03ac95 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 15:00:28 -0700 Subject: [PATCH 30/88] Add homogeneous literal validator. Ported from the go implementation. 
PiperOrigin-RevId: 899199115 --- validator/BUILD | 35 ++++ validator/homogeneous_literal_validator.cc | 190 ++++++++++++++++++ validator/homogeneous_literal_validator.h | 38 ++++ .../homogeneous_literal_validator_test.cc | 145 +++++++++++++ 4 files changed, 408 insertions(+) create mode 100644 validator/homogeneous_literal_validator.cc create mode 100644 validator/homogeneous_literal_validator.h create mode 100644 validator/homogeneous_literal_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index dec3d8616..98d1316c7 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -96,6 +96,41 @@ cc_library( ], ) +cc_library( + name = "homogeneous_literal_validator", + srcs = ["homogeneous_literal_validator.cc"], + hdrs = ["homogeneous_literal_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "homogeneous_literal_validator_test", + srcs = ["homogeneous_literal_validator_test.cc"], + deps = [ + ":homogeneous_literal_validator", + ":validator", + "//checker:validation_result", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:strings", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_test( name = "ast_depth_validator_test", srcs = ["ast_depth_validator_test.cc"], diff --git a/validator/homogeneous_literal_validator.cc b/validator/homogeneous_literal_validator.cc new file mode 100644 index 000000000..4a490dea2 --- /dev/null +++ b/validator/homogeneous_literal_validator.cc @@ -0,0 +1,190 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool InExemptFunction(const NavigableAstNode& node, + const std::vector& exempt_functions) { + const NavigableAstNode* parent = node.parent(); + while (parent != nullptr) { + if (parent->node_kind() == NodeKind::kCall) { + absl::string_view fn_name = parent->expr()->call_expr().function(); + for (const auto& exempt : exempt_functions) { + if (exempt == fn_name) { + return true; + } + } + } + parent = parent->parent(); + } + return false; +} + +bool IsOptional(const TypeSpec& t) { + return t.has_abstract_type() && t.abstract_type().name() == "optional_type"; +} + +const TypeSpec& GetOptionalParameter(const TypeSpec& t) { + return t.abstract_type().parameter_types()[0]; +} + +void TypeMismatch(ValidationContext& context, int64_t id, + const TypeSpec& expected, const TypeSpec& actual) { + context.ReportErrorAt( + id, absl::StrCat("expected type '", FormatTypeSpec(expected), + "' but found '", FormatTypeSpec(actual), "'")); +} + +bool TypeEquiv(const TypeSpec& a, const TypeSpec& b) { + if (a == b) { + return true; + } + + if (a.has_error() || b.has_error()) { + // Don't report mismatch if there's an error (type checking failed for the + // expression). 
+ return true; + } + + if (a.has_wrapper() && b.has_primitive()) { + return a.wrapper() == b.primitive(); + } else if (a.has_primitive() && b.has_wrapper()) { + return a.primitive() == b.wrapper(); + } + + if (a.has_list_type() && b.has_list_type()) { + return TypeEquiv(a.list_type().elem_type(), b.list_type().elem_type()); + } + + if (a.has_map_type() && b.has_map_type()) { + return TypeEquiv(a.map_type().key_type(), b.map_type().key_type()) && + TypeEquiv(a.map_type().value_type(), b.map_type().value_type()); + } + + if (a.has_abstract_type() && b.has_abstract_type() && + a.abstract_type().name() == b.abstract_type().name() && + a.abstract_type().parameter_types().size() == + b.abstract_type().parameter_types().size()) { + for (int i = 0; i < a.abstract_type().parameter_types().size(); ++i) { + if (!TypeEquiv(a.abstract_type().parameter_types()[i], + b.abstract_type().parameter_types()[i])) { + return false; + } + } + return true; + } + + return false; +} + +} // namespace + +Validation HomogeneousLiteralValidator( + std::vector exempt_functions) { + return Validation([exempt_functions = std::move(exempt_functions)]( + ValidationContext& context) -> bool { + bool valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kList) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& list_expr = node.expr()->list_expr(); + const auto& elements = list_expr.elements(); + const TypeSpec* expected_type = nullptr; + + for (const auto& element : elements) { + int64_t id = element.expr().id(); + const TypeSpec& actual_type = context.ast().GetTypeOrDyn(id); + const TypeSpec* type_to_check = &actual_type; + + if (element.optional() && IsOptional(actual_type)) { + type_to_check = &GetOptionalParameter(actual_type); + } + + if (expected_type == nullptr) { + expected_type = type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_type, *type_to_check))) { + 
TypeMismatch(context, id, *expected_type, *type_to_check); + valid = false; + break; + } + } + } else if (node.node_kind() == NodeKind::kMap) { + if (InExemptFunction(node, exempt_functions)) { + continue; + } + const auto& map_expr = node.expr()->map_expr(); + const auto& entries = map_expr.entries(); + const TypeSpec* expected_key_type = nullptr; + const TypeSpec* expected_value_type = nullptr; + + for (const auto& entry : entries) { + int64_t key_id = entry.key().id(); + int64_t val_id = entry.value().id(); + const TypeSpec& actual_key_type = context.ast().GetTypeOrDyn(key_id); + const TypeSpec& actual_val_type = context.ast().GetTypeOrDyn(val_id); + const TypeSpec* key_type_to_check = &actual_key_type; + const TypeSpec* val_type_to_check = &actual_val_type; + + if (entry.optional() && IsOptional(actual_val_type)) { + val_type_to_check = &GetOptionalParameter(actual_val_type); + } + + if (expected_key_type == nullptr) { + expected_key_type = key_type_to_check; + expected_value_type = val_type_to_check; + continue; + } + + if (!(TypeEquiv(*expected_key_type, *key_type_to_check))) { + TypeMismatch(context, key_id, *expected_key_type, + *key_type_to_check); + valid = false; + break; + } + if (!(TypeEquiv(*expected_value_type, *val_type_to_check))) { + TypeMismatch(context, val_id, *expected_value_type, + *val_type_to_check); + valid = false; + break; + } + } + } + } + return valid; + }); +} + +} // namespace cel diff --git a/validator/homogeneous_literal_validator.h b/validator/homogeneous_literal_validator.h new file mode 100644 index 000000000..e37648a25 --- /dev/null +++ b/validator/homogeneous_literal_validator.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ + +#include +#include + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that all literals in map or list literals +// are the same type. If the list or map is part of an argument to an exempted +// function, it is not checked. +Validation HomogeneousLiteralValidator( + std::vector exempt_functions); + +inline Validation HomogeneousLiteralValidator() { + // Default to exempting the strings extension "format" function. + return HomogeneousLiteralValidator({"format"}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_HOMOGENEOUS_LITERAL_VALIDATOR_H_ diff --git a/validator/homogeneous_literal_validator_test.cc b/validator/homogeneous_literal_validator_test.cc new file mode 100644 index 000000000..b027fa4b0 --- /dev/null +++ b/validator/homogeneous_literal_validator_test.cc @@ -0,0 +1,145 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/homogeneous_literal_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/validation_result.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/strings.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError(); + builder->AddLibrary(extensions::StringsCompilerLibrary()).IgnoreError(); + cel::Type message_type = cel::Type::Message( + builder->GetCheckerBuilder().descriptor_pool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("msg", message_type))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using HomogeneousLiteralValidatorTest = testing::TestWithParam; + +TEST_P(HomogeneousLiteralValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(HomogeneousLiteralValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid); + if (!test_case.valid) { + 
EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + HomogeneousLiteralValidatorTest, HomogeneousLiteralValidatorTest, + testing::Values( + // Lists + TestCase{"[1, 2, 3]", true}, TestCase{"['a', 'b', 'c']", true}, + TestCase{"[1, 'a']", false, "expected type 'int' but found 'string'"}, + TestCase{"[1, 2, 'a']", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1], [2]]", true}, + TestCase{"[[1], ['a']]", false, + "expected type 'list(int)' but found 'list(string)'"}, + + // Dyn casts + TestCase{"[dyn(1), dyn('a')]", true, ""}, + TestCase{"[dyn(1), 2]", false, "expected type 'dyn' but found 'int'"}, + + // Maps + TestCase{"{1: 'a', 2: 'b'}", true}, TestCase{"{'a': 1, 'b': 2}", true}, + TestCase{"{1: 'a', 'b': 2}", false, + "expected type 'int' but found 'string'"}, + TestCase{"{1: 'a', 2: 3}", false, + "expected type 'string' but found 'int'"}, + + // Optionals + TestCase{"[optional.of(1), optional.of(2)]", true}, + TestCase{"[optional.of(1), optional.of('b')]", false, + "expected type 'optional_type(int)' but found " + "'optional_type(string)'"}, + + TestCase{"[?optional.of(1), ?optional.of(2)]", true}, + TestCase{"[?optional.of(1), ?optional.of('a')]", false, + "expected type 'int' but found 'string'"}, + TestCase{"{?1: optional.of('a'), ?2: optional.none()}", true}, + TestCase{"{?1: optional.of('a'), ?2: optional.of(1)}", false, + "expected type 'string' but found 'int'"}, + + // Exempted Functions + TestCase{"'%v %v'.format([1, 'a'])", true}, + + // Mixed Primitives and Wrappers + TestCase{"[1, msg.single_int64_wrapper]", true}, + TestCase{"[msg.single_int64_wrapper, 1]", true}, + TestCase{"['foo', msg.single_string_wrapper]", true}, + TestCase{"[msg.single_string_wrapper, 'foo']", true}, + TestCase{"{1: msg.single_int64_wrapper, 2: 3}", true}, + TestCase{"{1: 2, 2: msg.single_int64_wrapper}", true}, + TestCase{"[[1], [msg.single_int64_wrapper]]", true}, + TestCase{"[optional.of(1), 
optional.of(msg.single_int64_wrapper)]", + true}, + TestCase{"[1, msg.single_string_wrapper]", false, + "expected type 'int' but found 'wrapper(string)'"}, + TestCase{"[msg.single_int64_wrapper, 'foo']", false, + "expected type 'wrapper(int)' but found 'string'"}, + TestCase{"[msg.single_int64_wrapper, msg.single_string_wrapper]", false, + "expected type 'wrapper(int)' but found 'wrapper(string)'"}, + + // Nested + TestCase{"[1, [2, 'a']]", false, + "expected type 'int' but found 'string'"}, + TestCase{"[[1, 2], [3, 4]]", true, ""}, + TestCase{"[{1: 2}, {'foo': 3}]", false, + "expected type 'map(int, int)' but found 'map(string, int)'"}, + TestCase{"[{1: 2}, {3: 'foo'}]", false, + "expected type 'map(int, int)' but found 'map(int, string)'"}, + TestCase{"[{1: 2}, {3: 4}]", true, ""})); + +} // namespace +} // namespace cel From 072542b0a1128e1c6b26ce774dc9c2824cb98ad2 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 16:41:07 -0700 Subject: [PATCH 31/88] Add RegexPatternValidator. 
PiperOrigin-RevId: 899243591 --- extensions/BUILD | 3 + extensions/regex_ext.cc | 8 +++ extensions/regex_ext.h | 8 +++ extensions/regex_ext_test.cc | 40 +++++++++++++ validator/BUILD | 35 +++++++++++ validator/regex_validator.cc | 96 +++++++++++++++++++++++++++++++ validator/regex_validator.h | 53 +++++++++++++++++ validator/regex_validator_test.cc | 91 +++++++++++++++++++++++++++++ 8 files changed, 334 insertions(+) create mode 100644 validator/regex_validator.cc create mode 100644 validator/regex_validator.h create mode 100644 validator/regex_validator_test.cc diff --git a/extensions/BUILD b/extensions/BUILD index c393ec13a..ff37e2c3f 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -774,6 +774,8 @@ cc_library( "//runtime:runtime_builder", "//runtime/internal:runtime_friend_access", "//runtime/internal:runtime_impl", + "//validator", + "//validator:regex_validator", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:bind_front", @@ -814,6 +816,7 @@ cc_test( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index c3d7cae53..9c06d90c2 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -42,6 +42,8 @@ #include "runtime/internal/runtime_friend_access.h" #include "runtime/internal/runtime_impl.h" #include "runtime/runtime_builder.h" +#include "validator/regex_validator.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -341,4 +343,10 @@ CompilerLibrary RegexExtCompilerLibrary() { return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } +Validation RegexExtValidator() { + return RegexPatternValidator( + 
/*id=*/"", + {{"regex.extract", 1}, {"regex.extractAll", 1}, {"regex.replace", 1}}); +} + } // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h index dc401f5bd..7b32aee00 100644 --- a/extensions/regex_ext.h +++ b/extensions/regex_ext.h @@ -81,6 +81,7 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "runtime/runtime_builder.h" +#include "validator/validator.h" namespace cel::extensions { @@ -119,5 +120,12 @@ CheckerLibrary RegexExtCheckerLibrary(); // regex.extractAll(target: str, pattern: str) -> list CompilerLibrary RegexExtCompilerLibrary(); +// Returns a `Validation` that checks all calls to the CEL regex extension +// functions. +// +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +Validation RegexExtValidator(); + } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index e69f7cce1..26d9936aa 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -46,6 +46,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "validator/validator.h" #include "google/protobuf/arena.h" #include "google/protobuf/extension_set.h" @@ -497,5 +498,44 @@ std::vector createRegexCheckerParams() { INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, ValuesIn(createRegexCheckerParams())); + +absl::StatusOr> CreateRegexExtCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCheckerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(RegexExtCompilerLibrary())); + return std::move(*builder).Build(); +} + +class RegexExtValidatorTest : public TestWithParam {}; + +TEST_P(RegexExtValidatorTest, Basic) { + 
ASSERT_OK_AND_ASSIGN(auto compiler, CreateRegexExtCompiler()); + + Validator validator; + validator.AddValidation(RegexExtValidator()); + + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(GetParam().expr_string)); + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), GetParam().error_substr.empty()) + << "Expression: " << GetParam().expr_string; + if (!GetParam().error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(GetParam().error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P(RegexExtValidatorTest, RegexExtValidatorTest, + testing::ValuesIn(std::vector{ + {"regex.extract('hello world', 'hello (.*)')"}, + {"regex.extract('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.extractAll('hello world', 'hello (.*)')"}, + {"regex.extractAll('hello world', 'hello ([') ", + "invalid regular expression"}, + {"regex.replace('hello world', 'hello', 'hi')"}, + {"regex.replace('hello world', 'he([', 'hi') ", + "invalid regular expression"}, + })); } // namespace } // namespace cel::extensions diff --git a/validator/BUILD b/validator/BUILD index 98d1316c7..e3f639142 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -109,6 +109,24 @@ cc_library( ], ) +cc_library( + name = "regex_validator", + srcs = ["regex_validator.cc"], + hdrs = ["regex_validator.h"], + deps = [ + ":validator", + "//common:ast", + "//common:constant", + "//common:expr", + "//common:navigable_ast", + "//common:standard_definitions", + "//internal:re2_options", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + cc_test( name = "homogeneous_literal_validator_test", srcs = ["homogeneous_literal_validator_test.cc"], @@ -147,4 +165,21 @@ cc_test( ], ) +cc_test( + name = "regex_validator_test", + srcs = ["regex_validator_test.cc"], + deps = [ + ":regex_validator", + ":validator", + "//common:decl", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + 
"//compiler:standard_library", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + licenses(["notice"]) diff --git a/validator/regex_validator.cc b/validator/regex_validator.cc new file mode 100644 index 000000000..df92bfb1e --- /dev/null +++ b/validator/regex_validator.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "internal/re2_options.h" +#include "validator/validator.h" +#include "re2/re2.h" + +namespace cel { + +namespace { + +bool CheckPattern(ValidationContext& context, const NavigableAstNode& node, + int arg_index) { + ABSL_DCHECK(node.expr()->has_call_expr()); + const auto& call_expr = node.expr()->call_expr(); + + const Expr* pattern_expr = nullptr; + + if (call_expr.has_target()) { + if (arg_index == 0) { + pattern_expr = &call_expr.target(); + } else if (call_expr.args().size() > arg_index - 1) { + pattern_expr = &call_expr.args()[arg_index - 1]; + } + } else if (call_expr.args().size() > arg_index) { + pattern_expr = &call_expr.args()[arg_index]; + } + + if (pattern_expr == nullptr || !pattern_expr->has_const_expr()) { + return true; + } + + const auto& 
const_expr = pattern_expr->const_expr(); + if (!const_expr.has_string_value()) { + return true; + } + + absl::string_view pattern_string = const_expr.string_value(); + RE2 re(pattern_string, internal::MakeRE2Options()); + if (!re.ok()) { + context.ReportErrorAt( + pattern_expr->id(), + absl::StrCat("invalid regular expression: ", re.error())); + return false; + } + return true; +} + +} // namespace + +Validation RegexPatternValidator( + absl::string_view id, std::vector config) { + return Validation( + [config = std::move(config)](ValidationContext& context) -> bool { + bool result = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() == NodeKind::kCall) { + for (const auto& config : config) { + if (node.expr()->call_expr().function() == config.function_name) { + if (!CheckPattern(context, node, config.pattern_arg_index)) { + result = false; + } + break; + } + } + } + } + return result; + }, + id); +} + +} // namespace cel diff --git a/validator/regex_validator.h b/validator/regex_validator.h new file mode 100644 index 000000000..15ee1755e --- /dev/null +++ b/validator/regex_validator.h @@ -0,0 +1,53 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "common/standard_definitions.h" +#include "validator/validator.h" + +namespace cel { + +// Configuration for the regex pattern validator. +struct RegexPatternValidatorConfig { + // The resolved function name. + std::string function_name; + // the index of the pattern argument (counting the receiver as arg 0 if + // present). + int pattern_arg_index; +}; + +// Returns a `Validation` that checks all calls to the given regex functions +// It validates that the specified argument is a valid regular expression if it +// is a literal string. +Validation RegexPatternValidator( + absl::string_view id, std::vector config); + +// Returns a `Validation` that checks all calls to the CEL `matches` function. +// It validates that if the pattern is a literal string, it is a valid regular +// expression. +inline Validation MatchesValidator() { + return RegexPatternValidator( + "cel.validator.matches", + {{std::string(StandardFunctions::kRegexMatch), 1}}); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_REGEX_VALIDATOR_H_ diff --git a/validator/regex_validator_test.cc b/validator/regex_validator_test.cc new file mode 100644 index 000000000..cfab1468d --- /dev/null +++ b/validator/regex_validator_test.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/regex_validator.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "common/decl.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + builder->AddLibrary(StandardCompilerLibrary()).IgnoreError(); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("p", StringType()))); + return builder->Build(); +} + +struct TestCase { + std::string expression; + bool valid; + std::string error_substr = ""; +}; + +using MatchesValidatorTest = testing::TestWithParam; + +TEST_P(MatchesValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(MatchesValidator()); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile(test_case.expression)); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + MatchesValidatorTest, MatchesValidatorTest, + testing::Values( + // Member calls + TestCase{"'hello'.matches('h.*')", true}, + TestCase{"'hello'.matches('h[')", false, "invalid regular expression"}, + TestCase{"'hello'.matches('h(a|b)')", true}, + TestCase{"'hello'.matches('h(a|b')", false, + "invalid regular expression"}, + // Global calls + TestCase{"matches('hello', 'h.*')", true}, + 
TestCase{"matches('hello', 'h[')", false, "invalid regular expression"}, + // Non-literal patterns (should not report regex errors) + TestCase{"'hello'.matches(p)", true}, + TestCase{"'hello'.matches('h' + 'ello')", true}, + TestCase{"'hello'.matches(dyn(1))", true}, + + // Empty pattern + TestCase{"'hello'.matches('')", true})); + +} // namespace +} // namespace cel From 9fd4d79dff5f5b95654f1eb60a55a104e063deac Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 13 Apr 2026 17:10:35 -0700 Subject: [PATCH 32/88] Add comprehension nesting limit validator. PiperOrigin-RevId: 899255266 --- validator/BUILD | 29 ++++++ validator/comprehension_nesting_validator.cc | 72 ++++++++++++++ validator/comprehension_nesting_validator.h | 31 ++++++ .../comprehension_nesting_validator_test.cc | 96 +++++++++++++++++++ 4 files changed, 228 insertions(+) create mode 100644 validator/comprehension_nesting_validator.cc create mode 100644 validator/comprehension_nesting_validator.h create mode 100644 validator/comprehension_nesting_validator_test.cc diff --git a/validator/BUILD b/validator/BUILD index e3f639142..9910a6b97 100644 --- a/validator/BUILD +++ b/validator/BUILD @@ -182,4 +182,33 @@ cc_test( ], ) +cc_library( + name = "comprehension_nesting_validator", + srcs = ["comprehension_nesting_validator.cc"], + hdrs = ["comprehension_nesting_validator.h"], + deps = [ + ":validator", + "//common:expr", + "//common:navigable_ast", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "comprehension_nesting_validator_test", + srcs = ["comprehension_nesting_validator_test.cc"], + deps = [ + ":comprehension_nesting_validator", + ":validator", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status:statusor", + ], +) + licenses(["notice"]) diff --git 
a/validator/comprehension_nesting_validator.cc b/validator/comprehension_nesting_validator.cc new file mode 100644 index 000000000..81c47cbc3 --- /dev/null +++ b/validator/comprehension_nesting_validator.cc @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "validator/comprehension_nesting_validator.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "common/expr.h" +#include "common/navigable_ast.h" +#include "validator/validator.h" + +namespace cel { + +namespace { + +bool IsEmptyRangeComprehension(const NavigableAstNode& node) { + ABSL_DCHECK(node.expr()->has_comprehension_expr()); + const auto& comp = node.expr()->comprehension_expr(); + return comp.has_iter_range() && comp.iter_range().has_list_expr() && + comp.iter_range().list_expr().elements().empty(); +} + +} // namespace + +Validation ComprehensionNestingLimitValidator(int limit) { + return Validation( + [limit](ValidationContext& context) -> bool { + bool is_valid = true; + for (const auto& node : + context.navigable_ast().Root().DescendantsPostorder()) { + if (node.node_kind() != NodeKind::kComprehension) { + continue; + } + if (IsEmptyRangeComprehension(node)) { + continue; + } + + int count = 0; + const NavigableAstNode* current = &node; + while (current != nullptr) { + if (current->node_kind() == NodeKind::kComprehension && + !IsEmptyRangeComprehension(*current)) { + count++; + } + current = 
current->parent(); + } + if (count > limit) { + context.ReportErrorAt( + node.expr()->id(), + absl::StrCat("comprehension nesting level of ", count, + " exceeds limit of ", limit)); + is_valid = false; + break; + } + } + return is_valid; + }, + "cel.validator.comprehension_nesting_limit"); +} + +} // namespace cel diff --git a/validator/comprehension_nesting_validator.h b/validator/comprehension_nesting_validator.h new file mode 100644 index 000000000..4dab78db0 --- /dev/null +++ b/validator/comprehension_nesting_validator.h @@ -0,0 +1,31 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ +#define THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ + +#include "validator/validator.h" + +namespace cel { + +// Returns a `Validation` that checks that comprehensions are not nested beyond +// the specified limit. +// +// Comprehensions with an empty iteration range (e.g. `cel.bind`) do not count +// towards the nesting limit. 
+Validation ComprehensionNestingLimitValidator(int limit); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_VALIDATOR_COMPREHENSION_NESTING_VALIDATOR_H_ diff --git a/validator/comprehension_nesting_validator_test.cc b/validator/comprehension_nesting_validator_test.cc new file mode 100644 index 000000000..c1b47f82d --- /dev/null +++ b/validator/comprehension_nesting_validator_test.cc @@ -0,0 +1,96 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "validator/comprehension_nesting_validator.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "validator/validator.h" + +namespace cel { +namespace { + +using ::testing::HasSubstr; + +absl::StatusOr> StdLibCompiler() { + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + return builder->Build(); +} + +struct TestCase { + std::string expression; + int limit; + bool valid; + std::string error_substr = ""; +}; + +using ComprehensionNestingValidatorTest = testing::TestWithParam; + +TEST_P(ComprehensionNestingValidatorTest, Validate) { + const auto& test_case = GetParam(); + Validator validator; + validator.AddValidation(ComprehensionNestingLimitValidator(test_case.limit)); + + ASSERT_OK_AND_ASSIGN(auto compiler, StdLibCompiler()); + auto result_or = compiler->Compile(test_case.expression); + if (!result_or.ok()) { + GTEST_SKIP() << "Expression failed to compile: " << test_case.expression + << " " << result_or.status().message(); + } + auto result = std::move(result_or).value(); + + validator.UpdateValidationResult(result); + + EXPECT_EQ(result.IsValid(), test_case.valid) + << "Expression: " << test_case.expression + << " Limit: " << test_case.limit; + if (!test_case.valid) { + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.error_substr)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ComprehensionNestingValidatorTest, ComprehensionNestingValidatorTest, + testing::Values( + TestCase{"[1, 2].all(x, x > 0)", 1, true}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 1, false, + "comprehension 
nesting level of 2 exceeds limit of 1"}, + TestCase{"[1, 2].all(x, [1, 2].all(y, x > y))", 2, true}, + // Empty range comprehension (does not count) + TestCase{"[].all(x, [1, 2].all(y, y > 0))", 1, true}, + TestCase{"cel.bind(x, [1, 2].all(y, y > 0), [1, 2].all(z, z > 0))", 1, + true}, + // Nested empty range comprehensions + TestCase{"[].all(x, [].all(y, true))", 0, true}, + // Deeply nested mixed + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 1, false, + "comprehension nesting level of 2 exceeds limit of 1"}, + TestCase{"[1].all(x, [].all(y, [2].all(z, true)))", 2, true})); + +} // namespace +} // namespace cel From 7a37461941067c8fe90cce8724dd5153b42c21bf Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 14 Apr 2026 13:12:00 -0700 Subject: [PATCH 33/88] Test fixes for macos builds - ignore conformance test that depends on charconv shortest float rep formatting - fix internal test with ambiguous overloads PiperOrigin-RevId: 899739434 --- conformance/BUILD | 44 ++------- internal/overflow_test.cc | 189 ++++++++++++++++++++++++-------------- 2 files changed, 126 insertions(+), 107 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index ba485f36d..139739891 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -164,7 +164,7 @@ _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/type_deduction.textproto", ] -_TESTS_TO_SKIP_MODERN = [ +_TESTS_TO_SKIP = [ # Tests which require spec changes. # TODO(issues/93): Deprecate Duration.getMilliseconds. "timestamps/duration_converters/get_milliseconds", @@ -197,10 +197,13 @@ _TESTS_TO_SKIP_MODERN = [ "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", "timestamps/timestamp_selectors_tz/getDayOfYear", # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. + # precision to preserve value. Not available on older compilers where we just use absl::Format. 
+ # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", ] +_TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP + _TESTS_TO_SKIP_MODERN_DASHBOARD = [ # Future features for CEL 1.0 # TODO(issues/119): Strong typing support for enums, specified but not implemented. @@ -208,34 +211,7 @@ _TESTS_TO_SKIP_MODERN_DASHBOARD = [ "enums/strong_proto3", ] -_TESTS_TO_SKIP_LEGACY = [ - # Tests which require spec changes. - # TODO(issues/93): Deprecate Duration.getMilliseconds. - "timestamps/duration_converters/get_milliseconds", - - # Broken test cases which should be supported. - # TODO(issues/112): Unbound functions result in empty eval response. - "basic/functions/unbound", - "basic/functions/unbound_is_runtime_error", - - # TODO(issues/97): Parse-only qualified variable lookup "x.y" with binding "x.y" or "y" within container "x" fails - "fields/qualified_identifier_resolution/qualified_ident,map_field_select,ident_with_longest_prefix_check,qualified_identifier_resolution_unchecked", - "namespace/qualified/self_eval_qualified_lookup", - "namespace/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", - # TODO(issues/117): Integer overflow on enum assignments should error. - "enums/legacy_proto2/select_big,select_neg", - - # Skip until fixed. - "wrappers/field_mask/to_json", - "wrappers/empty/to_json", - "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", - "parse/receiver_function_names", - - # Future features for CEL 1.0 - # TODO(issues/119): Strong typing support for enums, specified but not implemented. - "enums/strong_proto2", - "enums/strong_proto3", - +_TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ # Legacy value does not support optional_type. "optionals/optionals", @@ -245,14 +221,6 @@ _TESTS_TO_SKIP_LEGACY = [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # These depend on legacy US/ timezones. 
It's spotty if these are included with a normally - # configured timezone database. - "timestamps/timestamp_selectors_tz/getDayOfMonth_name_pos", - "timestamps/timestamp_selectors_tz/getDayOfYear", - # These depend on using charconv (or equivalent) to format doubles with shortest possible - # precision to preserve value. Not available on older compilers. - "conversions/string/double_hard", - # cel.@block "block_ext/basic/optional_list", "block_ext/basic/optional_map", diff --git a/internal/overflow_test.cc b/internal/overflow_test.cc index 38c5fa750..213e7a79d 100644 --- a/internal/overflow_test.cc +++ b/internal/overflow_test.cc @@ -57,25 +57,30 @@ INSTANTIATE_TEST_SUITE_P( CheckedIntMathTest, CheckedIntResultTest, ValuesIn(std::vector{ // Addition tests. - {"OneAddOne", [] { return CheckedAdd(1L, 1L); }, 2L}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1L); }, 1L}, - {"ZeroAddMinusOne", [] { return CheckedAdd(0, -1L); }, -1L}, - {"OneAddZero", [] { return CheckedAdd(1L, 0); }, 1L}, - {"MinusOneAddZero", [] { return CheckedAdd(-1L, 0); }, -1L}, + {"OneAddOne", [] { return CheckedAdd(int64_t{1L}, 1L); }, 2L}, + {"ZeroAddOne", [] { return CheckedAdd(int64_t{0}, 1L); }, 1L}, + {"ZeroAddMinusOne", [] { return CheckedAdd(int64_t{0}, -1L); }, -1L}, + {"OneAddZero", [] { return CheckedAdd(int64_t{1L}, 0); }, 1L}, + {"MinusOneAddZero", [] { return CheckedAdd(int64_t{-1L}, 0); }, -1L}, {"OneAddIntMax", - [] { return CheckedAdd(1L, std::numeric_limits::max()); }, + [] { + return CheckedAdd(int64_t{1L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneAddIntMin", - [] { return CheckedAdd(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedAdd(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, // Subtraction tests. 
- {"TwoSubThree", [] { return CheckedSub(2L, 3L); }, -1L}, - {"TwoSubZero", [] { return CheckedSub(2L, 0); }, 2L}, - {"ZeroSubTwo", [] { return CheckedSub(0, 2L); }, -2L}, - {"MinusTwoSubThree", [] { return CheckedSub(-2L, 3L); }, -5L}, - {"MinusTwoSubZero", [] { return CheckedSub(-2L, 0); }, -2L}, - {"ZeroSubMinusTwo", [] { return CheckedSub(0, -2L); }, 2L}, + {"TwoSubThree", [] { return CheckedSub(int64_t{2L}, 3L); }, -1L}, + {"TwoSubZero", [] { return CheckedSub(int64_t{2L}, 0); }, 2L}, + {"ZeroSubTwo", [] { return CheckedSub(int64_t{0}, 2L); }, -2L}, + {"MinusTwoSubThree", [] { return CheckedSub(int64_t{-2L}, 3L); }, -5L}, + {"MinusTwoSubZero", [] { return CheckedSub(int64_t{-2L}, 0); }, -2L}, + {"ZeroSubMinusTwo", [] { return CheckedSub(int64_t{0}, -2L); }, 2L}, {"IntMinSubIntMax", [] { return CheckedSub(std::numeric_limits::max(), @@ -84,66 +89,100 @@ INSTANTIATE_TEST_SUITE_P( absl::OutOfRangeError("integer overflow")}, // Multiplication tests. - {"TwoMulThree", [] { return CheckedMul(2L, 3L); }, 6L}, - {"MinusTwoMulThree", [] { return CheckedMul(-2L, 3L); }, -6L}, - {"MinusTwoMulMinusThree", [] { return CheckedMul(-2L, -3L); }, 6L}, - {"TwoMulMinusThree", [] { return CheckedMul(2L, -3L); }, -6L}, + {"TwoMulThree", [] { return CheckedMul(int64_t{2L}, 3L); }, 6L}, + {"MinusTwoMulThree", [] { return CheckedMul(int64_t{-2L}, 3L); }, -6L}, + {"MinusTwoMulMinusThree", [] { return CheckedMul(int64_t{-2L}, -3L); }, + 6L}, + {"TwoMulMinusThree", [] { return CheckedMul(int64_t{2L}, -3L); }, -6L}, {"TwoMulIntMax", - [] { return CheckedMul(2L, std::numeric_limits::max()); }, + [] { + return CheckedMul(int64_t{2L}, std::numeric_limits::max()); + }, absl::OutOfRangeError("integer overflow")}, {"MinusOneMulIntMin", - [] { return CheckedMul(-1L, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{-1L}, + std::numeric_limits::lowest()); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulMinusOne", - [] { return 
CheckedMul(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, {"IntMinMulZero", - [] { return CheckedMul(std::numeric_limits::lowest(), 0); }, + [] { + return CheckedMul(std::numeric_limits::lowest(), + int64_t{0}); + }, 0}, {"ZeroMulIntMin", - [] { return CheckedMul(0, std::numeric_limits::lowest()); }, + [] { + return CheckedMul(int64_t{0}, + std::numeric_limits::lowest()); + }, 0}, {"IntMaxMulZero", - [] { return CheckedMul(std::numeric_limits::max(), 0); }, 0}, + [] { + return CheckedMul(std::numeric_limits::max(), int64_t{0}); + }, + 0}, {"ZeroMulIntMax", - [] { return CheckedMul(0, std::numeric_limits::max()); }, 0}, + [] { + return CheckedMul(int64_t{0}, std::numeric_limits::max()); + }, + 0}, // Division cases. - {"ZeroDivOne", [] { return CheckedDiv(0, 1L); }, 0}, - {"TenDivTwo", [] { return CheckedDiv(10L, 2L); }, 5}, - {"TenDivMinusOne", [] { return CheckedDiv(10L, -1L); }, -10}, - {"MinusTenDivMinusOne", [] { return CheckedDiv(-10L, -1L); }, 10}, - {"MinusTenDivTwo", [] { return CheckedDiv(-10L, 2L); }, -5}, - {"OneDivZero", [] { return CheckedDiv(1L, 0L); }, + {"ZeroDivOne", [] { return CheckedDiv(int64_t{0}, 1L); }, 0}, + {"TenDivTwo", [] { return CheckedDiv(int64_t{10L}, 2L); }, 5}, + {"TenDivMinusOne", [] { return CheckedDiv(int64_t{10L}, -1L); }, -10}, + {"MinusTenDivMinusOne", [] { return CheckedDiv(int64_t{-10L}, -1L); }, + 10}, + {"MinusTenDivTwo", [] { return CheckedDiv(int64_t{-10L}, 2L); }, -5}, + {"OneDivZero", [] { return CheckedDiv(int64_t{1L}, 0L); }, absl::InvalidArgumentError("divide by zero")}, {"IntMinDivMinusOne", - [] { return CheckedDiv(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedDiv(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Modulus cases. 
- {"ZeroModTwo", [] { return CheckedMod(0, 2L); }, 0}, - {"TwoModTwo", [] { return CheckedMod(2L, 2L); }, 0}, - {"ThreeModTwo", [] { return CheckedMod(3L, 2L); }, 1L}, - {"TwoModZero", [] { return CheckedMod(2L, 0); }, + {"ZeroModTwo", [] { return CheckedMod(int64_t{0}, 2L); }, 0}, + {"TwoModTwo", [] { return CheckedMod(int64_t{2L}, 2L); }, 0}, + {"ThreeModTwo", [] { return CheckedMod(int64_t{3L}, 2L); }, 1L}, + {"TwoModZero", [] { return CheckedMod(int64_t{2L}, 0); }, absl::InvalidArgumentError("modulus by zero")}, {"IntMinModTwo", - [] { return CheckedMod(std::numeric_limits::lowest(), 2L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{2L}); + }, 0}, {"IntMaxModMinusOne", - [] { return CheckedMod(std::numeric_limits::max(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::max(), int64_t{-1L}); + }, 0}, {"IntMinModMinusOne", - [] { return CheckedMod(std::numeric_limits::lowest(), -1L); }, + [] { + return CheckedMod(std::numeric_limits::lowest(), + int64_t{-1L}); + }, absl::OutOfRangeError("integer overflow")}, // Negation cases. 
- {"NegateOne", [] { return CheckedNegation(1L); }, -1L}, + {"NegateOne", [] { return CheckedNegation(int64_t{1L}); }, -1L}, {"NegateMinInt64", [] { return CheckedNegation(std::numeric_limits::lowest()); }, absl::OutOfRangeError("integer overflow")}, // Numeric conversion cases for uint -> int, double -> int - {"Uint64Conversion", [] { return CheckedUint64ToInt64(1UL); }, 1L}, + {"Uint64Conversion", [] { return CheckedUint64ToInt64(uint64_t{1UL}); }, + 1L}, {"Uint32MaxConversion", [] { return CheckedUint64ToInt64( @@ -156,7 +195,8 @@ INSTANTIATE_TEST_SUITE_P( static_cast(std::numeric_limits::max())); }, absl::OutOfRangeError("out of int64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToInt64(100.1); }, 100L}, + {"DoubleConversion", [] { return CheckedDoubleToInt64(double{100.1}); }, + 100L}, {"DoubleInt64MaxConversionError", [] { return CheckedDoubleToInt64( @@ -201,9 +241,10 @@ INSTANTIATE_TEST_SUITE_P( }, absl::OutOfRangeError("out of int64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToInt64(-1.0e99); }, + [] { return CheckedDoubleToInt64(double{-1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, - {"PosRangeConversionError", [] { return CheckedDoubleToInt64(1.0e99); }, + {"PosRangeConversionError", + [] { return CheckedDoubleToInt64(double{1.0e99}); }, absl::OutOfRangeError("out of int64 range")}, }), [](const testing::TestParamInfo& info) { @@ -218,51 +259,58 @@ INSTANTIATE_TEST_SUITE_P( CheckedUintMathTest, CheckedUintResultTest, ValuesIn(std::vector{ // Addition tests. 
- {"OneAddOne", [] { return CheckedAdd(1UL, 1UL); }, 2UL}, - {"ZeroAddOne", [] { return CheckedAdd(0, 1UL); }, 1UL}, - {"OneAddZero", [] { return CheckedAdd(1UL, 0); }, 1UL}, + {"OneAddOne", [] { return CheckedAdd(uint64_t{1UL}, 1UL); }, 2UL}, + {"ZeroAddOne", [] { return CheckedAdd(uint64_t{0}, 1UL); }, 1UL}, + {"OneAddZero", [] { return CheckedAdd(uint64_t{1UL}, 0); }, 1UL}, {"OneAddIntMax", - [] { return CheckedAdd(1UL, std::numeric_limits::max()); }, + [] { + return CheckedAdd(uint64_t{1UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Subtraction tests. - {"OneSubOne", [] { return CheckedSub(1UL, 1UL); }, 0}, - {"ZeroSubOne", [] { return CheckedSub(0, 1UL); }, + {"OneSubOne", [] { return CheckedSub(uint64_t{1UL}, 1UL); }, 0}, + {"ZeroSubOne", [] { return CheckedSub(uint64_t{0}, 1UL); }, absl::OutOfRangeError("unsigned integer overflow")}, - {"OneSubZero", [] { return CheckedSub(1UL, 0); }, 1UL}, + {"OneSubZero", [] { return CheckedSub(uint64_t{1UL}, 0); }, 1UL}, // Multiplication tests. - {"OneMulOne", [] { return CheckedMul(1UL, 1UL); }, 1UL}, - {"ZeroMulOne", [] { return CheckedMul(0, 1UL); }, 0}, - {"OneMulZero", [] { return CheckedMul(1UL, 0); }, 0}, + {"OneMulOne", [] { return CheckedMul(uint64_t{1UL}, 1UL); }, 1UL}, + {"ZeroMulOne", [] { return CheckedMul(uint64_t{0}, 1UL); }, 0}, + {"OneMulZero", [] { return CheckedMul(uint64_t{1UL}, 0); }, 0}, {"TwoMulUintMax", - [] { return CheckedMul(2UL, std::numeric_limits::max()); }, + [] { + return CheckedMul(uint64_t{2UL}, + std::numeric_limits::max()); + }, absl::OutOfRangeError("unsigned integer overflow")}, // Division tests. 
- {"TwoDivTwo", [] { return CheckedDiv(2UL, 2UL); }, 1UL}, - {"TwoDivFour", [] { return CheckedDiv(2UL, 4UL); }, 0}, - {"OneDivZero", [] { return CheckedDiv(1UL, 0); }, + {"TwoDivTwo", [] { return CheckedDiv(uint64_t{2UL}, 2UL); }, 1UL}, + {"TwoDivFour", [] { return CheckedDiv(uint64_t{2UL}, 4UL); }, 0}, + {"OneDivZero", [] { return CheckedDiv(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("divide by zero")}, // Modulus tests. - {"TwoModTwo", [] { return CheckedMod(2UL, 2UL); }, 0}, - {"TwoModFour", [] { return CheckedMod(2UL, 4UL); }, 2UL}, - {"OneModZero", [] { return CheckedMod(1UL, 0); }, + {"TwoModTwo", [] { return CheckedMod(uint64_t{2UL}, 2UL); }, 0}, + {"TwoModFour", [] { return CheckedMod(uint64_t{2UL}, 4UL); }, 2UL}, + {"OneModZero", [] { return CheckedMod(uint64_t{1UL}, 0); }, absl::InvalidArgumentError("modulus by zero")}, // Conversion test cases for int -> uint, double -> uint. - {"Int64Conversion", [] { return CheckedInt64ToUint64(1L); }, 1UL}, + {"Int64Conversion", [] { return CheckedInt64ToUint64(int64_t{1L}); }, + 1UL}, {"Int64MaxConversion", [] { return CheckedInt64ToUint64(std::numeric_limits::max()); }, static_cast(std::numeric_limits::max())}, {"NegativeInt64ConversionError", - [] { return CheckedInt64ToUint64(-1L); }, + [] { return CheckedInt64ToUint64(int64_t{-1L}); }, absl::OutOfRangeError("out of uint64 range")}, - {"DoubleConversion", [] { return CheckedDoubleToUint64(100.1); }, - 100UL}, + {"DoubleConversion", + [] { return CheckedDoubleToUint64(double{100.1}); }, 100UL}, {"DoubleUint64MaxConversionError", [] { return CheckedDoubleToUint64( @@ -287,13 +335,14 @@ INSTANTIATE_TEST_SUITE_P( std::numeric_limits::infinity()); }, absl::OutOfRangeError("out of uint64 range")}, - {"NegConversionError", [] { return CheckedDoubleToUint64(-1.1); }, + {"NegConversionError", + [] { return CheckedDoubleToUint64(double{-1.1}); }, absl::OutOfRangeError("out of uint64 range")}, {"NegRangeConversionError", - [] { return CheckedDoubleToUint64(-1.0e99); 
}, + [] { return CheckedDoubleToUint64(double{-1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, {"PosRangeConversionError", - [] { return CheckedDoubleToUint64(1.0e99); }, + [] { return CheckedDoubleToUint64(double{1.0e99}); }, absl::OutOfRangeError("out of uint64 range")}, }), [](const testing::TestParamInfo& info) { @@ -571,7 +620,8 @@ TEST_P(CheckedConvertInt64Int32Test, Conversions) { ExpectResult(GetParam()); } INSTANTIATE_TEST_SUITE_P( CheckedConvertInt64Int32Test, CheckedConvertInt64Int32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedInt64ToInt32(1L); }, 1}, + {"SimpleConversion", [] { return CheckedInt64ToInt32(int64_t{1L}); }, + 1}, {"Int32MaxConversion", [] { return CheckedInt64ToInt32( @@ -610,7 +660,8 @@ TEST_P(CheckedConvertUint64Uint32Test, Conversions) { INSTANTIATE_TEST_SUITE_P( CheckedConvertUint64Uint32Test, CheckedConvertUint64Uint32Test, ValuesIn(std::vector{ - {"SimpleConversion", [] { return CheckedUint64ToUint32(1UL); }, 1U}, + {"SimpleConversion", + [] { return CheckedUint64ToUint32(uint64_t{1UL}); }, 1U}, {"Uint32MaxConversion", [] { return CheckedUint64ToUint32( From a69b0ea03670d65745564a931d350ba00e63ce54 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 14 Apr 2026 14:31:16 -0700 Subject: [PATCH 34/88] More macos test fixes. 
PiperOrigin-RevId: 899776008 --- env/env_yaml.cc | 2 +- env/env_yaml_test.cc | 3 ++- eval/public/builtin_func_test.cc | 42 +++++++++++++++++++------------- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index a6f66bd83..4ba16ea84 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -882,7 +882,7 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { case ConstantKindCase::kTimestamp: out << YAML::Key << "value" << YAML::Value; out << absl::FormatTime( - "%4Y-%2m-%2d%ET%2H:%2M:%E*SZ", + "%Y-%m-%d%ET%H:%M:%E*SZ", // NOLINTNEXTLINE(clang-diagnostic-deprecated-declarations) constant.timestamp_value(), absl::UTCTimeZone()); break; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index c3e4839af..25cc63206 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -880,7 +880,8 @@ INSTANTIATE_TEST_SUITE_P( })); std::string Unindent(std::string_view yaml) { - std::vector lines = absl::StrSplit(yaml, '\n'); + absl::string_view yaml_view = yaml; + std::vector lines = absl::StrSplit(yaml_view, '\n'); int indent = -1; std::vector unindented_lines; for (auto& line : lines) { diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 1eeb07193..dba71d307 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -750,22 +750,25 @@ TEST_F(BuiltinsTest, TestBytesConversions_string) { TEST_F(BuiltinsTest, TestDoubleConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), 100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateDouble(ref), + double{100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateInt64(ref), double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversions_string) { std::string ref = "-100.1"; - 
TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), -100.1); + TestTypeConverts(builtin::kDouble, CelValue::CreateString(&ref), + double{-100.1}); } TEST_F(BuiltinsTest, TestDoubleConversions_uint) { uint64_t ref = 100UL; - TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), 100.0); + TestTypeConverts(builtin::kDouble, CelValue::CreateUint64(ref), + double{100.0}); } TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { @@ -774,34 +777,36 @@ TEST_F(BuiltinsTest, TestDoubleConversionError_stringInvalid) { } TEST_F(BuiltinsTest, TestDynConversions) { - TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), 100.1); - TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), 100L); - TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kDyn, CelValue::CreateDouble(100.1), double{100.1}); + TestTypeConverts(builtin::kDyn, CelValue::CreateInt64(100L), int64_t{100L}); + TestTypeConverts(builtin::kDyn, CelValue::CreateUint64(100UL), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestIntConversions_int) { - TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateInt64(100L), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_Timestamp) { Timestamp ref; ref.set_seconds(100); - TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), 100L); + TestTypeConverts(builtin::kInt, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateDouble(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateString(&ref), int64_t{100L}); } TEST_F(BuiltinsTest, 
TestIntConversions_uint) { uint64_t ref = 100; - TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), 100L); + TestTypeConverts(builtin::kInt, CelValue::CreateUint64(ref), int64_t{100L}); } TEST_F(BuiltinsTest, TestIntConversions_doubleIntMin) { @@ -826,7 +831,7 @@ TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { // value, but it will rountrip to a valid 64-bit integer. double range = std::numeric_limits::max() - 512; TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), - std::numeric_limits::max() - 1023); + int64_t{std::numeric_limits::max() - 1023}); } TEST_F(BuiltinsTest, TestIntConversionError_doubleNegRange) { @@ -874,21 +879,24 @@ TEST_F(BuiltinsTest, TestIntConversionError_uintRange) { TEST_F(BuiltinsTest, TestUintConversions_double) { double ref = 100.1; - TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateDouble(ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_int) { int64_t ref = 100L; - TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateInt64(ref), uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_string) { std::string ref = "100"; - TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateString(&ref), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversions_uint) { - TestTypeConverts(builtin::kUint, CelValue::CreateUint64(100UL), 100UL); + TestTypeConverts(builtin::kUint, CelValue::CreateUint64(uint64_t{100UL}), + uint64_t{100UL}); } TEST_F(BuiltinsTest, TestUintConversionError_doubleNegRange) { From 0253936925f3d5274c24c8ffa7e5eed683ecdd41 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 15 Apr 2026 11:31:54 -0700 Subject: [PATCH 35/88] No-op doc change to trigger builds. 
PiperOrigin-RevId: 900266859 --- common/ast.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/ast.h b/common/ast.h index db336f52d..afd0575ad 100644 --- a/common/ast.h +++ b/common/ast.h @@ -136,10 +136,10 @@ class Ast final { expr_version_ = expr_version; } - // Computes the source location (line and column) for the given expression id + // Computes the source location (line and column) for the given expression ID // from the source info (which stores absolute positions). // - // Returns a default (empty) source location if the expression id is not found + // Returns a default (empty) source location if the expression ID is not found // or the source info is not populated correctly. SourceLocation ComputeSourceLocation(int64_t expr_id) const; From bbe1aaa9dd1caa60c48a0d2278b3577c3d1d5797 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 16 Apr 2026 15:36:24 -0700 Subject: [PATCH 36/88] Remove special error for null select target. Just return the general invalid select target error. 
PiperOrigin-RevId: 900943253 --- eval/eval/select_step.cc | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index b95915145..420f3ac31 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -19,7 +19,6 @@ #include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" -#include "eval/internal/errors.h" #include "internal/status_macros.h" #include "runtime/runtime_options.h" #include "google/protobuf/arena.h" @@ -158,13 +157,6 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { result_trail = trail.Step(&field_); } - if (arg->Is()) { - frame->value_stack().PopAndPush( - cel::ErrorValue(cel::runtime_internal::CreateError("Message is NULL")), - std::move(result_trail)); - return absl::OkStatus(); - } - absl::optional optional_arg; if (enable_optional_types_ && arg.IsOptional()) { @@ -354,10 +346,6 @@ class DirectSelectStep : public DirectExpressionStep { case ValueKind::kStruct: case ValueKind::kMap: break; - case ValueKind::kNull: - result = cel::ErrorValue( - cel::runtime_internal::CreateError("Message is NULL")); - return absl::OkStatus(); default: if (optional_arg) { break; From 69719524b2c99de9802239acc7a8f1cb8712955c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 17 Apr 2026 12:02:08 -0700 Subject: [PATCH 37/88] No public description PiperOrigin-RevId: 901405918 --- internal/to_address.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/to_address.h b/internal/to_address.h index 5dffef3c1..36e7eeb60 100644 --- a/internal/to_address.h +++ b/internal/to_address.h @@ -49,7 +49,7 @@ struct PointerTraitsToAddress { template struct PointerTraitsToAddress< - T, absl::void_t::to_address( + T, std::void_t::to_address( std::declval()))> > { static constexpr auto Dispatch( const T& p ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { From 77d98c5b13884b10f516b42034929505e3e46ab6 Mon Sep 17 
00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:17:59 -0700 Subject: [PATCH 38/88] Disable incompatible test on MacOS PiperOrigin-RevId: 901492992 --- checker/internal/format_type_name_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/checker/internal/format_type_name_test.cc b/checker/internal/format_type_name_test.cc index 23bc2bda9..ff04e04d2 100644 --- a/checker/internal/format_type_name_test.cc +++ b/checker/internal/format_type_name_test.cc @@ -101,6 +101,7 @@ TEST(FormatTypeNameTest, Opaque) { "tuple(tuple(int, int), tuple(int, int), tuple(int, int))"); } +#ifndef __APPLE__ TEST(FormatTypeNameTest, ArbitraryNesting) { google::protobuf::Arena arena; Type type = IntType(); @@ -111,6 +112,7 @@ TEST(FormatTypeNameTest, ArbitraryNesting) { EXPECT_THAT(FormatTypeName(type), MatchesRegex(R"(^(ptype\(){1000}int(\)){1000})")); } +#endif } // namespace } // namespace cel::checker_internal From e20bf48de055dcd02a517151624f21bff2eda160 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:18:25 -0700 Subject: [PATCH 39/88] Ensure RequestContext is in the descriptor pool Mark benchmark tests as 'manual' PiperOrigin-RevId: 901493173 --- eval/tests/BUILD | 20 ++++++++++++++++---- eval/tests/allocation_benchmark_test.cc | 3 +++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 8eeafd521..9163548d1 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -18,7 +18,10 @@ cc_test( srcs = [ "benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//eval/public:activation", @@ -52,7 +55,10 @@ cc_test( srcs = [ "modern_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:allocator", @@ -102,7 +108,10 @@ cc_test( srcs = [ "allocation_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ 
":request_context_cc_proto", "//eval/public:activation", @@ -151,7 +160,10 @@ cc_test( srcs = [ "expression_builder_benchmark_test.cc", ], - tags = ["benchmark"], + tags = [ + "benchmark", + "manual", + ], deps = [ ":request_context_cc_proto", "//common:minimal_descriptor_pool", diff --git a/eval/tests/allocation_benchmark_test.cc b/eval/tests/allocation_benchmark_test.cc index 5364d3fc0..425355e3a 100644 --- a/eval/tests/allocation_benchmark_test.cc +++ b/eval/tests/allocation_benchmark_test.cc @@ -169,6 +169,9 @@ static void BM_AllocateMessage(benchmark::State& state) { "google.api.expr.runtime.RequestContext{" "ip: '192.168.0.1'," "path: '/root'}"); + // Make sure RequestContext is loaded in the generated descriptor pool. + RequestContext context; + static_cast(context); auto builder = CreateCelExpressionBuilder(); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); From c62346262ea2ecaa98c98847651210d94af70348 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:27:27 -0700 Subject: [PATCH 40/88] Add missing 'alwayslink' directive PiperOrigin-RevId: 901496760 --- testing/testrunner/user_tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/testing/testrunner/user_tests/BUILD b/testing/testrunner/user_tests/BUILD index 140b77aef..53cd8f716 100644 --- a/testing/testrunner/user_tests/BUILD +++ b/testing/testrunner/user_tests/BUILD @@ -59,6 +59,7 @@ cc_library( "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", "@com_google_protobuf//:protobuf", ], + alwayslink = True, ) cc_library( From ecece1a7758a1219854f04d68de34bc73471ec76 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:35:23 -0700 Subject: [PATCH 41/88] Add missing absl/str_cat deps PiperOrigin-RevId: 901499946 --- internal/BUILD | 2 ++ internal/strings_test.cc | 1 + internal/testing.cc | 2 ++ 3 files changed, 5 insertions(+) diff --git a/internal/BUILD b/internal/BUILD index 6bd0f0a46..3891c635d 100644 --- a/internal/BUILD +++ 
b/internal/BUILD @@ -296,6 +296,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], ) @@ -312,6 +313,7 @@ cc_library( deps = [ ":status_macros", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], ) diff --git a/internal/strings_test.cc b/internal/strings_test.cc index d6c90473e..fcdb6d4ec 100644 --- a/internal/strings_test.cc +++ b/internal/strings_test.cc @@ -24,6 +24,7 @@ #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "internal/testing.h" diff --git a/internal/testing.cc b/internal/testing.cc index 77e4c65b4..84aa58cce 100644 --- a/internal/testing.cc +++ b/internal/testing.cc @@ -14,6 +14,8 @@ #include "internal/testing.h" +#include "absl/strings/str_cat.h" // IWYU pragma: keep + namespace cel::internal { void AddFatalFailure(const char* file, int line, absl::string_view expression, From 170a758a0f9c5963630b074ece3f28bb52158883 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 17 Apr 2026 15:49:49 -0700 Subject: [PATCH 42/88] Use ShortDebugString instead of implicit proto string serialization PiperOrigin-RevId: 901506584 --- internal/message_equality_test.cc | 53 ++++++++++++++++++------------- parser/parser_test.cc | 16 ++++++---- 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index bc5914bef..318138d9b 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -110,22 +110,22 @@ TEST_P(UnaryMessageEqualsTest, Equals) { } EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << 
rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs; + << lhs->ShortDebugString() << " " << rhs->ShortDebugString(); // Test any. auto lhs_any = PackMessage(*lhs); auto rhs_any = PackMessage(*rhs); EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs; + << lhs_any->ShortDebugString() << " " << rhs->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs << " " << *rhs_any; + << lhs->ShortDebugString() << " " << rhs_any->ShortDebugString(); EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any << " " << *rhs_any; + << lhs_any->ShortDebugString() << " " << rhs_any->ShortDebugString(); } } } @@ -455,28 +455,30 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() << " " + << rhs_message->ShortDebugString() << " " << rhs_field->name(); if (!lhs_field->is_repeated() && lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << 
*rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, lhs_message->GetReflection()->GetMessage( *lhs_message, lhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } if (!rhs_field->is_repeated() && rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -485,14 +487,16 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_message, rhs_field), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( *rhs_message, rhs_field), *lhs_message, lhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << lhs_field->name() << " " << *rhs_message - << " " << rhs_field->name(); + << lhs_message->ShortDebugString() << " " << lhs_field->name() + << " " << rhs_message->ShortDebugString() << " " + << rhs_field->name(); } // Test `google.protobuf.Any`. 
absl::optional, @@ -505,21 +509,24 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); if (!lhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( *lhs_any->first, lhs_any->second), *rhs_message, rhs_field, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_message; + << lhs_any->first->ShortDebugString() << " " + << rhs_message->ShortDebugString(); } } if (rhs_any) { EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); if (!rhs_any->second->is_repeated()) { EXPECT_THAT( MessageFieldEquals(*lhs_message, lhs_field, @@ -527,7 +534,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { *rhs_any->first, rhs_any->second), pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_message << " " << *rhs_any->first; + << lhs_message->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } if (lhs_any && rhs_any) { @@ -535,7 +543,8 @@ TEST_P(UnaryMessageFieldEqualsTest, Equals) { MessageFieldEquals(*lhs_any->first, lhs_any->second, *rhs_any->first, rhs_any->second, pool, factory), IsOkAndHolds(test_case.equal)) - << *lhs_any->first << " " << *rhs_any->second; + << lhs_any->first->ShortDebugString() << " " + << rhs_any->first->ShortDebugString(); } } } diff --git a/parser/parser_test.cc b/parser/parser_test.cc index aee121051..c96845e67 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1495,14 +1495,16 @@ TEST_P(ExpressionTest, Parse) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter 
w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); ; } @@ -1514,7 +1516,7 @@ TEST_P(ExpressionTest, Parse) { if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) - << result->parsed_expr(); + << result->parsed_expr().ShortDebugString(); ; } } @@ -1867,14 +1869,16 @@ TEST_P(UpdatedAccuVarDisabledTest, Parse) { KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.P, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.L.empty()) { LocationAdorner location_adorner(result->parsed_expr().source_info()); ExprPrinter w(location_adorner); std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) << result->parsed_expr(); + EXPECT_EQ(test_info.L, adorned_string) + << result->parsed_expr().ShortDebugString(); } if (!test_info.R.empty()) { @@ -1885,7 +1889,7 @@ TEST_P(UpdatedAccuVarDisabledTest, Parse) { if (!test_info.M.empty()) { EXPECT_EQ(test_info.M, ConvertMacroCallsToString( result.value().parsed_expr().source_info())) - << result->parsed_expr(); + << result->parsed_expr().ShortDebugString(); } } From 474683e6a183629bd76dafd2526a2b8dbaeae0be Mon Sep 17 00:00:00 2001 From: Tristan Swadell 
Date: Fri, 17 Apr 2026 18:04:04 -0700 Subject: [PATCH 43/88] Use explicit numeric types in tests PiperOrigin-RevId: 901553198 --- eval/public/builtin_func_test.cc | 149 +++++++++++++++++++------------ 1 file changed, 92 insertions(+), 57 deletions(-) diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index dba71d307..037fa8345 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -544,13 +544,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(93541L); ref.set_nanos(11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), 25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - 1559L); + int64_t{1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - 93541L); + int64_t{93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - 11L); + int64_t{11L}); std::string result = "93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -560,13 +561,14 @@ TEST_F(BuiltinsTest, TestDurationFunctions) { ref.set_seconds(-93541L); ref.set_nanos(-11000000L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), -25L); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateDuration(&ref), + int64_t{-25L}); TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateDuration(&ref), - -1559L); + int64_t{-1559L}); TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateDuration(&ref), - -93541L); + int64_t{-93541L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateDuration(&ref), - -11L); + int64_t{-11L}); result = "-93541.011s"; TestTypeConverts(builtin::kString, CelProtoWrapper::CreateDuration(&ref), @@ -595,23 +597,28 @@ TEST_F(BuiltinsTest, TestTimestampFunctions) { ref.set_seconds(1L); ref.set_nanos(11000000L); TestFunctions(builtin::kFullYear, 
CelProtoWrapper::CreateTimestamp(&ref), - 1970L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 0L); + int64_t{1970L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 0L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 1L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 0L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 1L); + int64_t{0L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{0L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{1L}); TestFunctions(builtin::kMilliseconds, CelProtoWrapper::CreateTimestamp(&ref), - 11L); + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 0L); + int64_t{0L}); } TEST_F(BuiltinsTest, TestTimestampConversionToString) { @@ -640,46 +647,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - 
CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, 
CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); // Test timestamp functions w/ fixed timezone ref.set_seconds(1L); @@ -690,46 +711,60 @@ TEST_F(BuiltinsTest, TestTimestampFunctionsWithTimeZone) { TestFunctionsWithParams(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), params, - 1969L); + int64_t{1969L}); TestFunctionsWithParams(builtin::kMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); TestFunctionsWithParams(builtin::kDayOfYear, - CelProtoWrapper::CreateTimestamp(&ref), params, 364L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{364L}); TestFunctionsWithParams(builtin::kDayOfMonth, - CelProtoWrapper::CreateTimestamp(&ref), params, 30L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{30L}); TestFunctionsWithParams(builtin::kDate, - CelProtoWrapper::CreateTimestamp(&ref), params, 31L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{31L}); TestFunctionsWithParams(builtin::kHours, - CelProtoWrapper::CreateTimestamp(&ref), params, 16L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{16L}); TestFunctionsWithParams(builtin::kMinutes, - CelProtoWrapper::CreateTimestamp(&ref), params, 0L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{0L}); TestFunctionsWithParams(builtin::kSeconds, - CelProtoWrapper::CreateTimestamp(&ref), params, 1L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{1L}); TestFunctionsWithParams(builtin::kMilliseconds, 
- CelProtoWrapper::CreateTimestamp(&ref), params, 11L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{11L}); ref.set_seconds(259200L); ref.set_nanos(0L); TestFunctionsWithParams(builtin::kDayOfWeek, - CelProtoWrapper::CreateTimestamp(&ref), params, 6L); + CelProtoWrapper::CreateTimestamp(&ref), params, + int64_t{6L}); // Test timestamp functions with negative value ref.set_seconds(-1L); ref.set_nanos(0L); TestFunctions(builtin::kFullYear, CelProtoWrapper::CreateTimestamp(&ref), - 1969L); - TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), 11L); + int64_t{1969L}); + TestFunctions(builtin::kMonth, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{11L}); TestFunctions(builtin::kDayOfYear, CelProtoWrapper::CreateTimestamp(&ref), - 364L); + int64_t{364L}); TestFunctions(builtin::kDayOfMonth, CelProtoWrapper::CreateTimestamp(&ref), - 30L); - TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), 31L); - TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), 23L); - TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), 59L); - TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), 59L); + int64_t{30L}); + TestFunctions(builtin::kDate, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{31L}); + TestFunctions(builtin::kHours, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{23L}); + TestFunctions(builtin::kMinutes, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); + TestFunctions(builtin::kSeconds, CelProtoWrapper::CreateTimestamp(&ref), + int64_t{59L}); TestFunctions(builtin::kDayOfWeek, CelProtoWrapper::CreateTimestamp(&ref), - 3L); + int64_t{3L}); TestTypeConversionError( builtin::kString, @@ -828,7 +863,7 @@ TEST_F(BuiltinsTest, TestIntConversions_doubleIntMinMinus1024) { TEST_F(BuiltinsTest, TestIntConversionError_doubleIntMaxMinus512) { // Converting int64_t max - 512 to a double will not roundtrip to the original - // value, but it will rountrip to a valid 
64-bit integer. + // value, but it will roundtrip to a valid 64-bit integer. double range = std::numeric_limits::max() - 512; TestTypeConverts(builtin::kInt, CelValue::CreateDouble(range), int64_t{std::numeric_limits::max() - 1023}); From 1629bb4739a26d3cdbbbe7d2f964121c241e81ed Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 20 Apr 2026 09:08:41 -0700 Subject: [PATCH 44/88] Update alloc macros in internal/new.cc PiperOrigin-RevId: 902667189 --- internal/new.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/internal/new.cc b/internal/new.cc index 5bd9e8158..05396e624 100644 --- a/internal/new.cc +++ b/internal/new.cc @@ -67,6 +67,13 @@ void* AlignedNew(size_t size, std::align_val_t alignment) { ThrowStdBadAlloc(); } return ptr; +#elif defined(__APPLE__) + void* ptr; + if (ABSL_PREDICT_FALSE( + posix_memalign(&ptr, static_cast(alignment), size) != 0)) { + ThrowStdBadAlloc(); + } + return ptr; #else void* ptr = std::aligned_alloc(static_cast(alignment), size); if (ABSL_PREDICT_FALSE(size != 0 && ptr == nullptr)) { @@ -107,7 +114,7 @@ void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { - Delete(ptr, size); + SizedDelete(ptr, size); } else { #if defined(_MSC_VER) _aligned_free(ptr); From 8c140223c1262ff9481d9bdef70fd020d154fcd1 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 20 Apr 2026 09:28:15 -0700 Subject: [PATCH 45/88] Update Bazel build rules for MacOS PiperOrigin-RevId: 902675874 --- .bazelrc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.bazelrc b/.bazelrc index a6d7a13f0..475706072 100644 --- a/.bazelrc +++ b/.bazelrc @@ -15,6 +15,13 @@ build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc build:msvc --build_tag_filters=-no_test_msvc +build:macos --cxxopt=-faligned-allocation +build:macos --cxxopt=-mmacosx-version-min=10.13 +build:macos 
--linkopt=-mmacosx-version-min=10.13 + +# ANTLR tool requires Java 17+. +build --java_runtime_version=remotejdk_17 + test --test_output=errors # Enable matchers in googletest From c1e3026841f93b01774264e2c7171d7e55549119 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 20 Apr 2026 13:33:49 -0700 Subject: [PATCH 46/88] Make error handling in wrapped legacy map consistent with modern version. The wrapped legacy implementation would use absl status return to propagate expected errors leading to inconsistencies in error handling depending on whether the underlying map was backed by a modern or legacy map. Update so that expected errors are cel::ErrorValue and unexpected errors (internal bug or client error) are returned as non-ok Status return value. Add Documentation for MapValue members. PiperOrigin-RevId: 902798507 --- common/legacy_value.cc | 17 ++-- common/values/custom_map_value.h | 16 ++-- common/values/legacy_map_value.h | 12 +-- common/values/map_value.h | 55 ++++++++---- common/values/null_value.h | 3 +- common/values/parsed_json_map_value.h | 12 +-- common/values/parsed_map_field_value.h | 12 +-- common/values/values.h | 1 - eval/eval/equality_steps.cc | 32 +++---- eval/public/builtin_func_test.cc | 3 +- eval/public/cel_type_registry.h | 6 -- eval/public/structs/field_access_impl.cc | 6 +- .../proto_message_type_adapter_test.cc | 4 +- .../container_membership_functions.cc | 90 ++++++++++--------- 14 files changed, 142 insertions(+), 127 deletions(-) diff --git a/common/legacy_value.cc b/common/legacy_value.cc index 5c81fdacb..7fbf16732 100644 --- a/common/legacy_value.cc +++ b/common/legacy_value.cc @@ -700,7 +700,8 @@ absl::Status LegacyMapValue::Get( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -732,7 +733,7 @@ 
absl::StatusOr LegacyMapValue::Find( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); auto cel_value = impl_->Get(arena, cel_key); @@ -764,11 +765,17 @@ absl::Status LegacyMapValue::Has( case ValueKind::kString: break; default: - return InvalidMapKeyTypeError(key.kind()); + *result = ErrorValue(InvalidMapKeyTypeError(key.kind())); + return absl::OkStatus(); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - CEL_ASSIGN_OR_RETURN(auto has, impl_->Has(cel_key)); - *result = BoolValue{has}; + absl::StatusOr has = impl_->Has(cel_key); + if (!has.ok()) { + *result = ErrorValue(std::move(has).status()); + return absl::OkStatus(); + } + + *result = BoolValue(*has); return absl::OkStatus(); } diff --git a/common/values/custom_map_value.h b/common/values/custom_map_value.h index 9e840e07f..ca6e1e025 100644 --- a/common/values/custom_map_value.h +++ b/common/values/custom_map_value.h @@ -225,7 +225,7 @@ class CustomMapValueInterface { // Returns the number of entries in this map. virtual size_t Size() const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. virtual absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -233,7 +233,7 @@ class CustomMapValueInterface { google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const = 0; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
virtual absl::Status ForEach( ForEachCallback callback, @@ -347,7 +347,7 @@ class CustomMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -356,7 +356,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -365,7 +365,7 @@ class CustomMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -374,7 +374,7 @@ class CustomMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -386,7 +386,7 @@ class CustomMapValue final // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ForEach( ForEachCallback callback, @@ -394,7 +394,7 @@ class CustomMapValue final google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr NewIterator() const; diff --git a/common/values/legacy_map_value.h b/common/values/legacy_map_value.h index 31865a873..c83b7fc2f 100644 --- a/common/values/legacy_map_value.h +++ b/common/values/legacy_map_value.h @@ -102,7 +102,7 @@ class LegacyMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -111,7 +111,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -120,7 +120,7 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -129,7 +129,7 @@ class LegacyMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,11 +137,11 @@ class LegacyMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/map_value.h b/common/values/map_value.h index ffbdea6c9..b6e69ea57 100644 --- a/common/values/map_value.h +++ b/common/values/map_value.h @@ -15,10 +15,16 @@ // IWYU pragma: private, include "common/value.h" // IWYU pragma: friend "common/value.h" -// `MapValue` represents values of the primitive `map` type. `MapValueView` -// is a non-owning view of `MapValue`. `MapValueInterface` is the abstract -// base class of implementations. `MapValue` and `MapValueView` act as smart -// pointers to `MapValueInterface`. +// `MapValue` represents values of the primitive `map` type. It provides a +// unified interface for accessing map contents, regardless of the underlying +// implementation (e.g., JSON, protobuf map field, or custom implementation). +// +// Public member functions: +// - `IsEmpty()` / `Size()`: Query map size. +// - `Get()` / `Find()` / `Has()`: Access entries by key. +// - `ListKeys()` / `NewIterator()` / `ForEach()`: Iterate over entries. +// - `ConvertToJson()` / `ConvertToJsonObject()`: JSON conversion. +// - `IsCustom()` / `AsCustom()` / `GetCustom()`: Access custom implementation. 
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_MAP_VALUE_H_ @@ -54,7 +60,6 @@ namespace cel { -class MapValueInterface; class MapValue; class Value; @@ -119,8 +124,13 @@ class MapValue final : private common_internal::MapValueMixin { absl::StatusOr Size() const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Get` stores the value associated with `key` in `result`. If `key` is not + // found, a `no such key` error is stored in `result`. If another error occurs + // (e.g., invalid key type), an `ErrorValue` is stored in `result`. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -128,8 +138,13 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `Find` returns `true` if `key` is found in the map, and stores the + // associated value in `result`. If `key` is not found, `false` is returned + // and `result` is unchanged. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::StatusOr Find( const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -137,8 +152,13 @@ class MapValue final : private common_internal::MapValueMixin { google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for - // documentation.
+ // `Has` returns `true` if `key` is found in the map, and stores the BoolValue + // result in `result`. In case of an error, the result is set to an + // ErrorValue. + // + // A non-ok status may be returned if an unexpected error is encountered or to + // propagate an error from a custom implementation, in which case `result` is + // unspecified. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, @@ -146,28 +166,25 @@ class MapValue final : private common_internal::MapValueMixin { Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ListKeys` returns a `ListValue` containing all keys in the map. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for - // documentation. + // `ForEachCallback` is the callback type for `ForEach`. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `ForEach` calls `callback` for each entry in the map. Iteration continues + // until all entries are visited or `callback` returns an error or `false`. absl::Status ForEach( ForEachCallback callback, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) const; - // See the corresponding member function of `MapValueInterface` for - // documentation. + // `NewIterator` returns a new iterator for the map. 
absl::StatusOr NewIterator() const; // Returns `true` if this value is an instance of a custom map value. diff --git a/common/values/null_value.h b/common/values/null_value.h index 53c3161a1..d4d05dba3 100644 --- a/common/values/null_value.h +++ b/common/values/null_value.h @@ -37,8 +37,7 @@ namespace cel { class Value; class NullValue; -// `NullValue` represents values of the primitive `duration` type. - +// `NullValue` represents the CEL `null` value. class NullValue final : private common_internal::ValueMixin { public: static constexpr ValueKind kKind = ValueKind::kNull; diff --git a/common/values/parsed_json_map_value.h b/common/values/parsed_json_map_value.h index b20fe032b..ba8d3490d 100644 --- a/common/values/parsed_json_map_value.h +++ b/common/values/parsed_json_map_value.h @@ -132,7 +132,7 @@ class ParsedJsonMapValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -141,7 +141,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::StatusOr Find( const Value& key, @@ -150,7 +150,7 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -159,7 +159,7 @@ class ParsedJsonMapValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -167,11 +167,11 @@ class ParsedJsonMapValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/parsed_map_field_value.h b/common/values/parsed_map_field_value.h index 3478f75bc..21d686bfd 100644 --- a/common/values/parsed_map_field_value.h +++ b/common/values/parsed_map_field_value.h @@ -117,7 +117,7 @@ class ParsedMapFieldValue final size_t Size() const; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Get(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -126,7 +126,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Get; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
absl::StatusOr Find( const Value& key, @@ -135,7 +135,7 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const; using MapValueMixin::Find; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status Has(const Value& key, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -144,7 +144,7 @@ class ParsedMapFieldValue final Value* absl_nonnull result) const; using MapValueMixin::Has; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. absl::Status ListKeys( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, @@ -152,11 +152,11 @@ class ParsedMapFieldValue final google::protobuf::Arena* absl_nonnull arena, ListValue* absl_nonnull result) const; using MapValueMixin::ListKeys; - // See the corresponding type declaration of `MapValueInterface` for + // See the corresponding type declaration of `MapValue` for // documentation. using ForEachCallback = typename CustomMapValueInterface::ForEachCallback; - // See the corresponding member function of `MapValueInterface` for + // See the corresponding member function of `MapValue` for // documentation. 
absl::Status ForEach( ForEachCallback callback, diff --git a/common/values/values.h b/common/values/values.h index c9703dcbb..aaa6f8659 100644 --- a/common/values/values.h +++ b/common/values/values.h @@ -48,7 +48,6 @@ namespace cel { class ValueInterface; class ListValueInterface; -class MapValueInterface; class StructValueInterface; class Value; diff --git a/eval/eval/equality_steps.cc b/eval/eval/equality_steps.cc index e134069d5..d720302e4 100644 --- a/eval/eval/equality_steps.cc +++ b/eval/eval/equality_steps.cc @@ -132,15 +132,11 @@ class IterativeEqualityStep : public ExpressionStepBase { absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, const Value& item, const MapValue& container) { - absl::StatusOr result = {BoolValue(false)}; switch (item.kind()) { case ValueKind::kBool: case ValueKind::kString: case ValueKind::kInt: case ValueKind::kUint: - result = container.Has(item, frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - break; case ValueKind::kDouble: break; default: @@ -148,9 +144,12 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, cel::runtime_internal::CreateNoMatchingOverloadError( cel::builtin::kIn)); } + Value result; + CEL_RETURN_IF_ERROR(container.Has(item, frame.descriptor_pool(), + frame.message_factory(), frame.arena(), + &result)); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + if (result.IsTrue()) { return result; } @@ -159,10 +158,10 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? 
Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromUint64(item.GetUint().NativeValue()); if (number.LosslessConvertibleToInt()) { - result = container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + CEL_RETURN_IF_ERROR( + container.Has(IntValue(number.AsInt()), frame.descriptor_pool(), + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } @@ -173,21 +172,16 @@ absl::StatusOr EvaluateInMap(ExecutionFrameBase& frame, ? Number::FromDouble(item.GetDouble().NativeValue()) : Number::FromInt64(item.GetInt().NativeValue()); if (number.LosslessConvertibleToUint()) { - result = + CEL_RETURN_IF_ERROR( container.Has(UintValue(number.AsUint()), frame.descriptor_pool(), - frame.message_factory(), frame.arena()); - if (result.ok() && result.value().IsBool() && - result.value().GetBool().NativeValue()) { + frame.message_factory(), frame.arena(), &result)); + if (result.IsTrue()) { return result; } } } - if (!result.ok()) { - return BoolValue(false); - } - - return result; + return BoolValue(false); } absl::StatusOr EvaluateIn(ExecutionFrameBase& frame, const Value& item, diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 037fa8345..b73a2dc55 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1632,7 +1632,8 @@ TEST_F(BuiltinsTest, TestMapInError) { CelValue result_value; ASSERT_NO_FATAL_FAILURE(PerformRun( builtin::kIn, {}, {key, CelValue::CreateMap(&cel_map)}, &result_value)); - EXPECT_TRUE(result_value.IsBool()); + ASSERT_TRUE(result_value.IsBool()) + << key.DebugString() << " : " << result_value.DebugString(); EXPECT_FALSE(result_value.BoolOrDie()); } diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h index 0c01eb8e9..3fb80bcea 100644 --- a/eval/public/cel_type_registry.h +++ 
b/eval/public/cel_type_registry.h @@ -28,7 +28,6 @@ #include "base/type_provider.h" #include "eval/public/structs/legacy_type_adapter.h" #include "eval/public/structs/legacy_type_provider.h" -#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "runtime/type_registry.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -139,11 +138,6 @@ class CelTypeRegistry { private: // Internal modern registry. cel::TypeRegistry modern_type_registry_; - - // TODO(uncreated-issue/44): This is needed to inspect the registered legacy type - // providers for client tests. This can be removed when they are migrated to - // use the modern APIs. - std::shared_ptr legacy_type_provider_; }; } // namespace google::api::expr::runtime diff --git a/eval/public/structs/field_access_impl.cc b/eval/public/structs/field_access_impl.cc index 3b3cb9847..2bd9fff9d 100644 --- a/eval/public/structs/field_access_impl.cc +++ b/eval/public/structs/field_access_impl.cc @@ -139,8 +139,7 @@ class FieldAccessor { case FieldDescriptor::TYPE_BYTES: return CelValue::CreateBytesView(value); default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Error handling C++ string conversion"); + break; } break; } @@ -153,8 +152,7 @@ class FieldAccessor { return CelValue::CreateInt64(enum_value); } default: - return absl::Status(absl::StatusCode::kInvalidArgument, - "Unhandled C++ type conversion"); + break; } return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 088d20d48..32608bc3f 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -69,8 +69,8 @@ class ProtoMessageTypeAccessorTest : public testing::TestWithParam { bool use_generic_instance = GetParam(); if (use_generic_instance) { // implementation detail: in 
general, type info implementations may - // return a different accessor object based on the messsage instance, but - // this implemenation returns the same one no matter the message. + // return a different accessor object based on the message instance, but + // this implementation returns the same one no matter the message. return *GetGenericProtoTypeInfoInstance().GetAccessApis(dummy_); } else { diff --git a/runtime/standard/container_membership_functions.cc b/runtime/standard/container_membership_functions.cc index 9f5ca3755..cc0638429 100644 --- a/runtime/standard/container_membership_functions.cc +++ b/runtime/standard/container_membership_functions.cc @@ -174,15 +174,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(BoolValue(key), descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(BoolValue(key), descriptor_pool, + message_factory, arena, &has)); + if (has.IsTrue()) { + return has; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return has; }; auto intKeyInSet = @@ -191,27 +192,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = - map_value.Has(IntValue(key), descriptor_pool, message_factory, arena); + Value result; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(key), descriptor_pool, + message_factory, arena, &result)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return 
std::move(*result); + if (result.IsTrue()) { + return result; } Number number = Number::FromInt64(key); if (number.LosslessConvertibleToUint()) { - const auto& result = - map_value.Has(UintValue(number.AsUint()), descriptor_pool, - message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value result_alt; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, + arena, &result_alt)); + if (result_alt.IsTrue()) { + return result_alt; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return result; }; auto stringKeyInSet = @@ -220,14 +220,16 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - auto result = map_value.Has(key, descriptor_pool, message_factory, arena); - if (result.ok()) { - return std::move(*result); + Value result; + CEL_RETURN_IF_ERROR( + map_value.Has(key, descriptor_pool, message_factory, arena, &result)); + if (result.IsBool()) { + return result; } if (enable_heterogeneous_equality) { return BoolValue(false); } - return ErrorValue(result.status()); + return result; }; auto uintKeyInSet = @@ -236,26 +238,26 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { - const auto& result = - map_value.Has(UintValue(key), descriptor_pool, message_factory, arena); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(key), descriptor_pool, + message_factory, arena, &has)); if (enable_heterogeneous_equality) { - if (result.ok() && result->IsTrue()) { - return 
std::move(*result); + if (has.IsTrue()) { + return has; } + Value has_alt; Number number = Number::FromUint64(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, + arena, &has_alt)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); } - if (!result.ok()) { - return ErrorValue(result.status()); - } - return std::move(*result); + return has; }; auto doubleKeyInSet = @@ -265,17 +267,21 @@ absl::Status RegisterMapMembershipFunctions(FunctionRegistry& registry, google::protobuf::Arena* absl_nonnull arena) -> absl::StatusOr { Number number = Number::FromDouble(key); if (number.LosslessConvertibleToInt()) { - const auto& result = map_value.Has( - IntValue(number.AsInt()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(IntValue(number.AsInt()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } if (number.LosslessConvertibleToUint()) { - const auto& result = map_value.Has( - UintValue(number.AsUint()), descriptor_pool, message_factory, arena); - if (result.ok() && result->IsTrue()) { - return std::move(*result); + Value has; + CEL_RETURN_IF_ERROR(map_value.Has(UintValue(number.AsUint()), + descriptor_pool, message_factory, arena, + &has)); + if (has.IsTrue()) { + return has; } } return BoolValue(false); From c2c1205482616c260313eda6b506cde5f5dc04f7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 21 Apr 2026 16:03:30 -0700 Subject: [PATCH 47/88] Add ExpressionContainer to hold namespace config ExpressionContainer will hold configuration for namespace resolution rules in a given expression, adding support for abbrevs (imports) and 
aliases. PiperOrigin-RevId: 903474367 --- common/BUILD | 23 +++++ common/container.cc | 189 +++++++++++++++++++++++++++++++++++++++ common/container.h | 116 ++++++++++++++++++++++++ common/container_test.cc | 103 +++++++++++++++++++++ 4 files changed, 431 insertions(+) create mode 100644 common/container.cc create mode 100644 common/container.h create mode 100644 common/container_test.cc diff --git a/common/BUILD b/common/BUILD index e289ef413..ea6246b51 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1140,3 +1140,26 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "container", + srcs = ["container.cc"], + hdrs = ["container.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "container_test", + srcs = ["container_test.cc"], + deps = [ + ":container", + "//internal:testing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + ], +) diff --git a/common/container.cc b/common/container.cc new file mode 100644 index 000000000..4abceea2d --- /dev/null +++ b/common/container.cc @@ -0,0 +1,189 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "common/container.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel { +namespace { + +// Basic validation for accidental misuse. Does not fully validate against the +// CEL grammar rules for identifiers. +bool IsIdentifierChar(char c) { + return c == '_' || std::isalnum(static_cast(c)); +} + +bool IsValidQualifiedName(absl::string_view name) { + bool dot_ok = false; + for (char c : name) { + if (c == '.') { + if (!dot_ok) { + return false; + } + dot_ok = false; + continue; + } + if (!IsIdentifierChar(c)) { + return false; + } + dot_ok = true; + } + // Must not end in a dot. + return dot_ok; +} + +bool IsValidAlias(absl::string_view alias) { + if (alias.empty()) { + return false; + } + for (char c : alias) { + if (!IsIdentifierChar(c)) { + return false; + } + } + return true; +} + +bool IsAbreviation(absl::string_view alias, absl::string_view name) { + auto pos = name.rfind('.'); + return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && + alias == name.substr(pos + 1); +} + +} // namespace + +bool ExpressionContainer::AliasListing::IsAbbreviation() const { + return IsAbreviation(alias, name); +} + +absl::StatusOr ExpressionContainer::Create( + absl::string_view name) { + ExpressionContainer container; + + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + return container; +} + +absl::Status ExpressionContainer::SetContainer(absl::string_view name) { + if (name.empty()) { + container_ = ""; + return absl::OkStatus(); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + for (const auto& entry : aliases_) { + const std::string& alias = entry.first; + if (name == alias || + (name.size() > alias.size() && + absl::string_view(name).substr(0, alias.size()) == alias && + 
name.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("container name collides with alias: ", alias)); + } + } + + container_ = std::string(name); + return absl::OkStatus(); +} + +absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + if (!IsValidQualifiedName(abrev)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev)); + } + + auto pos = abrev.rfind('.'); + if (pos == 0 || pos == absl::string_view::npos || pos == abrev.size() - 1) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); + } + + absl::string_view alias = abrev.substr(pos + 1); + return AddAlias(alias, abrev); +} + +absl::Status ExpressionContainer::AddAlias(absl::string_view alias, + absl::string_view name) { + if (!IsValidAlias(alias)) { + return absl::InvalidArgumentError(absl::StrCat( + "alias must be non-empty and simple (not qualified): ", alias)); + } + + if (!IsValidQualifiedName(name)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid qualified name: ", name)); + } + + if (auto it = aliases_.find(alias); it != aliases_.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "alias collides with existing reference: ", alias, " -> ", it->second)); + } + + if (container_ == alias || + (container_.size() > alias.size() && + absl::string_view(container_).substr(0, alias.size()) == alias && + container_.at(alias.size()) == '.')) { + return absl::InvalidArgumentError( + absl::StrCat("alias collides with container name: ", alias)); + } + + aliases_.insert({std::string(alias), std::string(name)}); + return absl::OkStatus(); +} + +absl::string_view ExpressionContainer::FindAlias( + absl::string_view alias) const { + auto it = aliases_.find(alias); + if (it != aliases_.end()) { + return it->second; + } + return ""; +} + +std::vector ExpressionContainer::ListAbbreviations() const { + std::vector res; + for (const 
auto& entry : aliases_) { + if (IsAbreviation(entry.first, entry.second)) { + res.push_back(entry.second); + } + } + return res; +} + +std::vector +ExpressionContainer::ListAliases() const { + std::vector res; + for (const auto& entry : aliases_) { + res.push_back({entry.first, entry.second}); + } + return res; +} + +} // namespace cel diff --git a/common/container.h b/common/container.h new file mode 100644 index 000000000..a6555a8ac --- /dev/null +++ b/common/container.h @@ -0,0 +1,116 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace cel { + +// ExpressionContainer represents the namespace configuration for a CEL +// expression. +// +// The container defines the default resolution order for names referenced in +// the expression. It generally maps to a protobuf package and follows +// approximately the same resolution rules as protobuf or C++ namespaces. +// +// Aliases declare short names that can be referenced without resolving against +// the scopes defined by the container. For consistency, an alias cannot be +// a prefix of the container name. Aliases are always unqualified identifiers. 
+// +// An abbreviation is a special case of alias that behaves like an import or +// using declaration in other languages. (pkg.TypeName -> TypeName). +// +// For better traceability, prefer using abbreviations over aliases. +class ExpressionContainer { + public: + struct AliasListing { + std::string alias; + std::string name; + + bool IsAbbreviation() const; + }; + + ExpressionContainer() = default; + + static absl::StatusOr Create(absl::string_view name); + + ExpressionContainer(const ExpressionContainer&) = default; + ExpressionContainer(ExpressionContainer&&) = default; + ExpressionContainer& operator=(const ExpressionContainer&) = default; + ExpressionContainer& operator=(ExpressionContainer&&) = default; + + // Returns the full name of the container. + // + // The default value is an empty string meaning no container. + absl::string_view container() const { return container_; } + + // Sets the container name. + // + // Returns an error if the container name is malformed or conflicts with an + // existing alias. + absl::Status SetContainer(absl::string_view name); + + // Adds an abbreviation to the container. + // + // Returns an error if the abbreviation is malformed or conflicts with the + // container or an existing alias. + absl::Status AddAbbreviation(absl::string_view abrev); + + // Adds an alias to the container. + // + // Returns an error if the alias is malformed or conflicts with the container + // or an existing alias. + absl::Status AddAlias(absl::string_view alias, absl::string_view name); + + // Returns the full name of the alias or an empty string if not found. + // + // The returned string view may be invalidated by updates to the + // ExpressionContainer. + absl::string_view FindAlias(absl::string_view alias) const; + + // Utility method for listing the abbreviations in the container. + // Order is not guaranteed. + std::vector ListAbbreviations() const; + + // Utility method for listing the aliases in the container. 
+ // Includes abbreviations. + // Order is not guaranteed. + std::vector ListAliases() const; + + // Removes all aliases and abbreviations from the container. + void clear() { + container_.clear(); + aliases_.clear(); + } + + private: + explicit ExpressionContainer(absl::string_view name); + + std::string container_; + + // alias -> full name. + absl::flat_hash_map aliases_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc new file mode 100644 index 000000000..d8c052040 --- /dev/null +++ b/common/container_test.cc @@ -0,0 +1,103 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "common/container.h" + +#include "absl/status/status.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +TEST(ExpressionContainerTest, DefaultConstructed) { + ExpressionContainer container; + EXPECT_THAT(container.container(), IsEmpty()); + EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); +} + +TEST(ExpressionContainerTest, SetContainer) { + ExpressionContainer container; + EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); + EXPECT_THAT(container.container(), Eq("my.container.name")); + EXPECT_THAT(container.SetContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, AddAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); +} + +TEST(ExpressionContainerTest, AddAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); + EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); +} + +TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); + EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); + + EXPECT_THAT(container.ListAbbreviations(), + UnorderedElementsAre("qual.pkg.Abbr")); + + auto aliases = container.ListAliases(); + EXPECT_THAT(aliases, SizeIs(2)); +} + +TEST(ExpressionContainerTest, InvalidAbbreviation) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + 
ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAbbreviation(""), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation(".pkg"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAbbreviation("pkg."), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, InvalidAlias) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo.bar", "baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.AddAlias("foo", ".baz"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, CollidesWithContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + ExpressionContainer::Create("my.container")); + EXPECT_THAT(container.AddAlias("my", "bar"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace cel From 40176739fca23895ef5b7e1923d24746eeb8d519 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 22 Apr 2026 11:23:30 -0700 Subject: [PATCH 48/88] Wire ExpressionContainer into type checker. Update to Make* convention for factories to match other things in /common. Not used yet. 
PiperOrigin-RevId: 903948803 --- checker/BUILD | 1 + checker/internal/BUILD | 5 +-- checker/internal/type_check_env.h | 8 ++--- checker/internal/type_checker_builder_impl.cc | 9 +++-- checker/internal/type_checker_builder_impl.h | 6 +++- checker/internal/type_checker_impl.cc | 15 ++++----- checker/internal/type_checker_impl_test.cc | 15 +++++---- checker/type_checker_builder.h | 9 ++++- common/container.cc | 2 +- common/container.h | 29 +++++++++++++--- common/container_test.cc | 33 +++++++++++++++---- 11 files changed, 95 insertions(+), 37 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index d5eb3601c..7b151d6a8 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -88,6 +88,7 @@ cc_library( deps = [ ":checker_options", ":type_checker", + "//common:container", "//common:decl", "//common:type", "@com_google_absl//absl/base:nullability", diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 3f64417a0..73e5c177d 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -67,6 +67,7 @@ cc_library( deps = [ ":descriptor_pool_type_introspector", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//internal:status_macros", @@ -120,7 +121,6 @@ cc_library( "type_checker_impl.h", ], deps = [ - ":descriptor_pool_type_introspector", ":format_type_name", ":namespace_generator", ":type_check_env", @@ -136,9 +136,9 @@ cc_library( "//common:ast_visitor", "//common:ast_visitor_base", "//common:constant", + "//common:container", "//common:decl", "//common:expr", - "//common:source", "//common:type", "//common:type_kind", "//internal:lexis", @@ -172,6 +172,7 @@ cc_test( "//checker:type_check_issue", "//checker:validation_result", "//common:ast", + "//common:container", "//common:decl", "//common:expr", "//common:source", diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 520b0eab6..491e4b550 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -30,6 
+30,7 @@ #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -91,7 +92,6 @@ class TypeCheckEnv { absl_nonnull std::shared_ptr descriptor_pool) : descriptor_pool_(std::move(descriptor_pool)), - container_(""), proto_type_introspector_( std::make_shared( descriptor_pool_.get())) { @@ -104,9 +104,9 @@ class TypeCheckEnv { TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; - const std::string& container() const { return container_; } + const ExpressionContainer& container() const { return container_; } - void set_container(std::string container) { + void set_container(ExpressionContainer container) { container_ = std::move(container); } @@ -206,7 +206,7 @@ class TypeCheckEnv { // If set, an arena was needed to allocate types in the environment. absl_nullable std::shared_ptr arena_; - std::string container_; + ExpressionContainer container_; // Used to resolve fields on message types. 
std::shared_ptr proto_type_introspector_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 7545aa949..9ebcb4e34 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -343,7 +343,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( absl::StatusOr> TypeCheckerBuilderImpl::Build() { TypeCheckEnv env(descriptor_pool_); - env.set_container(container_); + env.set_container(expression_container_); if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } @@ -479,7 +479,12 @@ void TypeCheckerBuilderImpl::AddTypeProvider( } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { - container_ = container; + expression_container_.SetContainer(container).IgnoreError(); +} + +void TypeCheckerBuilderImpl::SetExpressionContainer( + ExpressionContainer container) { + expression_container_ = std::move(container); } void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) { diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 3b3472232..7a099040b 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -31,6 +31,7 @@ #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -76,6 +77,9 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { void set_container(absl::string_view container) override; + void SetExpressionContainer( + ExpressionContainer expression_container) override; + const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { @@ -137,7 +141,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector libraries_; 
absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; - std::string container_; + ExpressionContainer expression_container_; absl::optional expected_type_; }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 8e8047755..df8f83683 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -187,14 +187,12 @@ class ResolveVisitor : public AstVisitorBase { bool requires_disambiguation; }; - ResolveVisitor(absl::string_view container, - NamespaceGenerator namespace_generator, + ResolveVisitor(NamespaceGenerator namespace_generator, const TypeCheckEnv& env, const Ast& ast, TypeInferenceContext& inference_context, std::vector& issues, google::protobuf::Arena* absl_nonnull arena) - : container_(container), - namespace_generator_(std::move(namespace_generator)), + : namespace_generator_(std::move(namespace_generator)), env_(&env), inference_context_(&inference_context), issues_(&issues), @@ -326,7 +324,7 @@ class ResolveVisitor : public AstVisitorBase { ReportIssue(TypeCheckIssue::CreateError( ast_->ComputeSourceLocation(expr.id()), absl::StrCat("undeclared reference to '", name, "' (in container '", - container_, "')"))); + env_->container().container(), "')"))); } void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, @@ -407,7 +405,6 @@ class ResolveVisitor : public AstVisitorBase { return DynType(); } - absl::string_view container_; NamespaceGenerator namespace_generator_; const TypeCheckEnv* absl_nonnull env_; TypeInferenceContext* absl_nonnull inference_context_; @@ -1260,12 +1257,12 @@ absl::StatusOr TypeCheckerImpl::Check( google::protobuf::Arena type_arena; std::vector issues; - CEL_ASSIGN_OR_RETURN(auto generator, - NamespaceGenerator::Create(env_.container())); + CEL_ASSIGN_OR_RETURN( + auto generator, NamespaceGenerator::Create(env_.container().container())); TypeInferenceContext type_inference_context( &type_arena, 
options_.enable_legacy_null_assignment); - ResolveVisitor visitor(env_.container(), std::move(generator), env_, *ast, + ResolveVisitor visitor(std::move(generator), env_, *ast, type_inference_context, issues, &type_arena); TraversalOptions opts; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 6eccc3701..714e669cd 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -35,6 +35,7 @@ #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/container.h" #include "common/decl.h" #include "common/expr.h" #include "common/source.h" @@ -757,7 +758,7 @@ TEST(TypeCheckerImplTest, NestedComprehensions) { TEST(TypeCheckerImplTest, ComprehensionVarsShadowNamespacePriorityRules) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("com"); + env.set_container(*MakeExpressionContainer("com")); google::protobuf::Arena arena; ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1462,7 +1463,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); TypeCheckerImpl impl(std::move(env)); @@ -1483,7 +1484,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); env.AddTypeProvider(std::make_unique()); CheckerOptions options; @@ -1508,7 +1509,7 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { TypeCheckEnv 
env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -1538,7 +1539,7 @@ TEST_P(WktCreationTest, MessageCreation) { const CheckedExprTestCase& test_case = GetParam(); TypeCheckEnv env(GetSharedTestingDescriptorPool()); env.AddTypeProvider(std::make_unique()); - env.set_container("google.protobuf"); + env.set_container(*MakeExpressionContainer("google.protobuf")); ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); @@ -1696,7 +1697,7 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( @@ -2247,7 +2248,7 @@ TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); - env.set_container("cel.expr.conformance.proto3"); + env.set_container(*MakeExpressionContainer("cel.expr.conformance.proto3")); google::protobuf::LinkMessageReflection(); ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index e5942b157..b3a86f64c 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -25,6 +25,7 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/type_checker.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -132,10 +133,16 @@ class TypeCheckerBuilder { // // This is used for resolving references in the expressions being built. 
// + // Prefer setting the container via SetExpressionContainer(). + // // Note: if set multiple times, the last value is used. This can lead to - // surprising behavior if used in a custom library. + // surprising behavior if used in a custom library. If container is not a + // valid container name, the operation is ignored. virtual void set_container(absl::string_view container) = 0; + virtual void SetExpressionContainer( + ExpressionContainer expression_container) = 0; + // The current options for the TypeChecker being built. virtual const CheckerOptions& options() const = 0; diff --git a/common/container.cc b/common/container.cc index 4abceea2d..dbfa987d0 100644 --- a/common/container.cc +++ b/common/container.cc @@ -75,7 +75,7 @@ bool ExpressionContainer::AliasListing::IsAbbreviation() const { return IsAbreviation(alias, name); } -absl::StatusOr ExpressionContainer::Create( +absl::StatusOr MakeExpressionContainer( absl::string_view name) { ExpressionContainer container; diff --git a/common/container.h b/common/container.h index a6555a8ac..cd40aaef9 100644 --- a/common/container.h +++ b/common/container.h @@ -51,8 +51,6 @@ class ExpressionContainer { ExpressionContainer() = default; - static absl::StatusOr Create(absl::string_view name); - ExpressionContainer(const ExpressionContainer&) = default; ExpressionContainer(ExpressionContainer&&) = default; ExpressionContainer& operator=(const ExpressionContainer&) = default; @@ -103,14 +101,37 @@ class ExpressionContainer { } private: - explicit ExpressionContainer(absl::string_view name); - std::string container_; // alias -> full name. absl::flat_hash_map aliases_; }; +// Factory function for creating an ExpressionContainer. +absl::StatusOr MakeExpressionContainer( + absl::string_view name); + +// Factory function for creating an ExpressionContainer with a list of +// abbreviations. +template +absl::StatusOr MakeExpressionContainer( + absl::string_view name, Args&&... 
abbrevs) { + ExpressionContainer container; + absl::Status status = container.SetContainer(name); + if (!status.ok()) { + return status; + } + absl::string_view abbrevs_view[] = {std::forward(abbrevs)...}; + for (absl::string_view abrev : abbrevs_view) { + status.Update(container.AddAbbreviation(abrev)); + if (!status.ok()) { + return status; + } + } + + return container; +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_CONTAINER_H_ diff --git a/common/container_test.cc b/common/container_test.cc index d8c052040..991362320 100644 --- a/common/container_test.cc +++ b/common/container_test.cc @@ -33,6 +33,27 @@ TEST(ExpressionContainerTest, DefaultConstructed) { EXPECT_THAT(container.FindAlias("foo"), IsEmpty()); } +TEST(ExpressionContainerTest, MakeExpressionContainer) { + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("my.container")); + EXPECT_THAT(container.container(), Eq("my.container")); + + EXPECT_THAT(MakeExpressionContainer("..invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExpressionContainerTest, MakeExpressionContainerWithAbbrevs) { + ASSERT_OK_AND_ASSIGN( + ExpressionContainer container, + MakeExpressionContainer("my.container", "pkg.Abbr", "qual.pkg.Abbr2")); + EXPECT_THAT(container.container(), Eq("my.container")); + EXPECT_THAT(container.FindAlias("Abbr"), Eq("pkg.Abbr")); + EXPECT_THAT(container.FindAlias("Abbr2"), Eq("qual.pkg.Abbr2")); + + EXPECT_THAT(MakeExpressionContainer("my.container", "invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + TEST(ExpressionContainerTest, SetContainer) { ExpressionContainer container; EXPECT_THAT(container.SetContainer("my.container.name"), IsOk()); @@ -43,21 +64,21 @@ TEST(ExpressionContainerTest, SetContainer) { TEST(ExpressionContainerTest, AddAlias) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("foo", 
"bar.baz"), IsOk()); EXPECT_THAT(container.FindAlias("foo"), Eq("bar.baz")); } TEST(ExpressionContainerTest, AddAbbreviation) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation("qual.pkg.TypeName"), IsOk()); EXPECT_THAT(container.FindAlias("TypeName"), Eq("qual.pkg.TypeName")); } TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation("qual.pkg.Abbr"), IsOk()); EXPECT_THAT(container.AddAlias("AliasSym", "some.long.name"), IsOk()); @@ -70,7 +91,7 @@ TEST(ExpressionContainerTest, ListAbbreviationsAndAliases) { TEST(ExpressionContainerTest, InvalidAbbreviation) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAbbreviation(""), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(container.AddAbbreviation("pkg"), @@ -83,7 +104,7 @@ TEST(ExpressionContainerTest, InvalidAbbreviation) { TEST(ExpressionContainerTest, InvalidAlias) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("", "bar"), StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(container.AddAlias("foo.bar", "baz"), @@ -94,7 +115,7 @@ TEST(ExpressionContainerTest, InvalidAlias) { TEST(ExpressionContainerTest, CollidesWithContainer) { ASSERT_OK_AND_ASSIGN(ExpressionContainer container, - ExpressionContainer::Create("my.container")); + MakeExpressionContainer("my.container")); EXPECT_THAT(container.AddAlias("my", "bar"), StatusIs(absl::StatusCode::kInvalidArgument)); } From 0341704b5c79509317360faec164e034d765fd9b Mon Sep 17 00:00:00 2001 
From: Jonathan Tatum Date: Thu, 23 Apr 2026 13:01:06 -0700 Subject: [PATCH 49/88] Add support for resolving aliases. PiperOrigin-RevId: 904596936 --- checker/internal/BUILD | 5 + checker/internal/namespace_generator.cc | 109 +++++-- checker/internal/namespace_generator.h | 44 ++- checker/internal/namespace_generator_test.cc | 70 +++-- checker/internal/type_checker_impl.cc | 4 +- checker/internal/type_checker_impl_test.cc | 290 ++++++++++++++++++- common/BUILD | 1 + common/container.cc | 43 +-- common/container.h | 5 +- common/container_test.cc | 2 + 10 files changed, 482 insertions(+), 91 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 73e5c177d..1af48af57 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -89,8 +89,12 @@ cc_library( srcs = ["namespace_generator.cc"], hdrs = ["namespace_generator.h"], deps = [ + "//common:container", "//internal:lexis", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -104,6 +108,7 @@ cc_test( srcs = ["namespace_generator_test.cc"], deps = [ ":namespace_generator", + "//common:container", "//internal:testing", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", diff --git a/checker/internal/namespace_generator.cc b/checker/internal/namespace_generator.cc index e5b2cfa51..7ab7628e4 100644 --- a/checker/internal/namespace_generator.cc +++ b/checker/internal/namespace_generator.cc @@ -20,7 +20,7 @@ #include #include "absl/functional/function_ref.h" -#include "absl/status/status.h" +#include "absl/log/absl_check.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -28,19 +28,20 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include 
"absl/types/span.h" +#include "common/container.h" #include "internal/lexis.h" namespace cel::checker_internal { namespace { -bool FieldSelectInterpretationCandidates( +bool FieldSelectInterpretationCandidatesImpl( absl::string_view prefix, - absl::Span partly_qualified_name, + absl::Span partly_qualified_name, bool prefix_is_alias, absl::FunctionRef callback) { for (int i = 0; i < partly_qualified_name.size(); ++i) { std::string buf; int count = partly_qualified_name.size() - i; - auto end_idx = count - 1; + auto end_idx = count - (prefix_is_alias ? 0 : 1); auto ident = absl::StrJoin(partly_qualified_name.subspan(0, count), "."); absl::string_view candidate = ident; if (absl::StartsWith(candidate, ".")) { @@ -54,28 +55,44 @@ bool FieldSelectInterpretationCandidates( return false; } } + if (prefix_is_alias) { + return callback(prefix, 0); + } return true; } +bool FieldSelectInterpretationCandidates( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/false, callback); +} + +bool FieldSelectInterpretationCandidatesWithAlias( + absl::string_view prefix, + absl::Span partly_qualified_name, + absl::FunctionRef callback) { + return FieldSelectInterpretationCandidatesImpl( + prefix, partly_qualified_name, /*prefix_is_alias=*/true, callback); +} + } // namespace absl::StatusOr NamespaceGenerator::Create( - absl::string_view container) { + const ExpressionContainer& expression_container) { std::vector candidates; + absl::string_view container = expression_container.container(); if (container.empty()) { - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } - if (absl::StartsWith(container, ".")) { - return absl::InvalidArgumentError("container must not start with a '.'"); - } std::string prefix; for (auto segment : absl::StrSplit(container, '.')) { - if 
(!internal::LexisIsIdentifier(segment)) { - return absl::InvalidArgumentError( - "container must only contain valid identifier segments"); - } + // Assumes the ExpressionContainer has already validated the container + // and aliases. + ABSL_DCHECK(internal::LexisIsIdentifier(segment)); if (prefix.empty()) { prefix = segment; } else { @@ -84,31 +101,75 @@ absl::StatusOr NamespaceGenerator::Create( candidates.push_back(prefix); } std::reverse(candidates.begin(), candidates.end()); - return NamespaceGenerator(std::move(candidates)); + return NamespaceGenerator(&expression_container, std::move(candidates)); } void NamespaceGenerator::GenerateCandidates( - absl::string_view unqualified_name, - absl::FunctionRef callback) { - if (absl::StartsWith(unqualified_name, ".")) { - callback(unqualified_name.substr(1)); + absl::string_view simple_name, + absl::FunctionRef callback) const { + // Special case for root-relative names. Aliases still apply first. + bool is_root_relative = absl::StartsWith(simple_name, "."); + if (is_root_relative) { + simple_name = simple_name.substr(1); + } + + // The name is unqualified, but may include a namespace (struct creation). + // This is just a quirk of the parser. 
+ if (auto dot_pos = simple_name.find('.'); + dot_pos != absl::string_view::npos) { + absl::string_view first_segment = simple_name.substr(0, dot_pos); + absl::string_view rest = simple_name.substr(dot_pos + 1); + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + callback(absl::StrCat(resolved_alias, ".", rest)); + return; + } + } else { + if (auto resolved_alias = expression_container_->FindAlias(simple_name); + !resolved_alias.empty()) { + callback(resolved_alias); + return; + } + } + + if (is_root_relative) { + callback(simple_name); return; } + for (const auto& prefix : candidates_) { - std::string candidate = absl::StrCat(prefix, ".", unqualified_name); + std::string candidate = absl::StrCat(prefix, ".", simple_name); if (!callback(candidate)) { return; } } - callback(unqualified_name); + callback(simple_name); } void NamespaceGenerator::GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback) { - // Special case for explicit root relative name. e.g. '.com.example.Foo' - if (!partly_qualified_name.empty() && - absl::StartsWith(partly_qualified_name[0], ".")) { + absl::FunctionRef callback) const { + if (partly_qualified_name.empty()) { + return; + } + + // Special case for root-relative names. Aliases still apply first. + absl::string_view first_segment = partly_qualified_name[0]; + bool is_root_relative = absl::StartsWith(first_segment, "."); + if (is_root_relative) { + first_segment = first_segment.substr(1); + } + + if (auto resolved_alias = expression_container_->FindAlias(first_segment); + !resolved_alias.empty()) { + FieldSelectInterpretationCandidatesWithAlias( + resolved_alias, partly_qualified_name.subspan(1), callback); + // If the alias matches, we don't check the container even if name + // resolution fails. 
+ return; + } + + if (is_root_relative) { FieldSelectInterpretationCandidates("", partly_qualified_name, callback); return; } diff --git a/checker/internal/namespace_generator.h b/checker/internal/namespace_generator.h index 18c40dbda..61cb1956b 100644 --- a/checker/internal/namespace_generator.h +++ b/checker/internal/namespace_generator.h @@ -19,18 +19,26 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "common/container.h" namespace cel::checker_internal { // Utility class for generating namespace qualified candidates for reference // resolution. +// +// This class is expected to be scoped to a single type checking operation and +// borrows the ExpressionContainer from the TypeCheckEnv. class NamespaceGenerator { public: - static absl::StatusOr Create(absl::string_view container); + static absl::StatusOr Create( + const ExpressionContainer& expression_container + ABSL_ATTRIBUTE_LIFETIME_BOUND); // Copyable and movable. NamespaceGenerator(const NamespaceGenerator&) = default; @@ -51,8 +59,18 @@ class NamespaceGenerator { // and unqualified name foo // // com.google.foo, com.foo, foo - void GenerateCandidates(absl::string_view unqualified_name, - absl::FunctionRef callback); + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (foo = com.example) + // unqualified name foo + // + // com.example + void GenerateCandidates( + absl::string_view simple_name, + absl::FunctionRef callback) const; // For a partially qualified name, generate all the qualified candidates in // order of resolution precedence and pass them to the provided callback. 
The @@ -72,16 +90,30 @@ class NamespaceGenerator { // (com.Foo).bar, // (Foo.bar), // (Foo).bar, + // + // If aliases are present, they override the normal container resolution. + // + // Example: + // container (com.google) + // alias (Foo = com.example.Foo) + // partially qualified name Foo.bar + // + // (com.example.Foo.bar), + // (com.example.Foo).bar, void GenerateCandidates( absl::Span partly_qualified_name, - absl::FunctionRef callback); + absl::FunctionRef callback) const; private: - explicit NamespaceGenerator(std::vector candidates) - : candidates_(std::move(candidates)) {} + explicit NamespaceGenerator( + const ExpressionContainer* absl_nonnull expression_container, + std::vector candidates) + : candidates_(std::move(candidates)), + expression_container_(expression_container) {} // list of prefixes ordered from most qualified to least. std::vector candidates_; + const ExpressionContainer* absl_nonnull expression_container_; }; } // namespace cel::checker_internal diff --git a/checker/internal/namespace_generator_test.cc b/checker/internal/namespace_generator_test.cc index da174748a..ba9bb88a4 100644 --- a/checker/internal/namespace_generator_test.cc +++ b/checker/internal/namespace_generator_test.cc @@ -18,19 +18,20 @@ #include #include -#include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "common/container.h" #include "internal/testing.h" namespace cel::checker_internal { namespace { -using ::absl_testing::StatusIs; +using ::absl_testing::IsOk; using ::testing::ElementsAre; using ::testing::Pair; TEST(NamespaceGeneratorTest, EmptyContainer) { - ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create("")); + ExpressionContainer container; + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -40,8 +41,9 @@ TEST(NamespaceGeneratorTest, EmptyContainer) { } 
TEST(NamespaceGeneratorTest, MultipleSegments) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates("foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -51,8 +53,9 @@ TEST(NamespaceGeneratorTest, MultipleSegments) { } TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector candidates; generator.GenerateCandidates(".foo", [&](absl::string_view candidate) { candidates.push_back(std::string(candidate)); @@ -61,18 +64,46 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsRootNamespace) { EXPECT_THAT(candidates, ElementsAre("foo")); } -TEST(NamespaceGeneratorTest, InvalidContainers) { - EXPECT_THAT(NamespaceGenerator::Create(".com.example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com..example"), - StatusIs(absl::StatusCode::kInvalidArgument)); - EXPECT_THAT(NamespaceGenerator::Create("com.$example"), - StatusIs(absl::StatusCode::kInvalidArgument)); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + 
}); + EXPECT_THAT( + candidates, + ElementsAre(Pair("com.example.foo.Bar", 1), Pair("com.example.foo", 0), + Pair("com.foo.Bar", 1), Pair("com.foo", 0), + Pair("foo.Bar", 1), Pair("foo", 0))); } -TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAlias("foo", "bar.baz"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + std::vector qualified_ident = {"foo", "Bar"}; + std::vector> candidates; + generator.GenerateCandidates( + qualified_ident, [&](absl::string_view candidate, int segment_index) { + candidates.push_back(std::pair(std::string(candidate), segment_index)); + return true; + }); + EXPECT_THAT(candidates, + ElementsAre(Pair("bar.baz.Bar", 1), Pair("bar.baz", 0))); +} + +TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationAliasNoMatch) { + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_THAT(container.AddAbbreviation("foo.Bar"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); + // No match on the alias (Bar) since it's not the first segment. 
std::vector qualified_ident = {"foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( @@ -89,8 +120,9 @@ TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretation) { TEST(NamespaceGeneratorTest, MultipleSegmentsSelectInterpretationRootNamespace) { - ASSERT_OK_AND_ASSIGN(auto generator, - NamespaceGenerator::Create("com.example")); + ExpressionContainer container; + ASSERT_THAT(container.SetContainer("com.example"), IsOk()); + ASSERT_OK_AND_ASSIGN(auto generator, NamespaceGenerator::Create(container)); std::vector qualified_ident = {".foo", "Bar"}; std::vector> candidates; generator.GenerateCandidates( diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index df8f83683..6697cc06a 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1257,8 +1257,8 @@ absl::StatusOr TypeCheckerImpl::Check( google::protobuf::Arena type_arena; std::vector issues; - CEL_ASSIGN_OR_RETURN( - auto generator, NamespaceGenerator::Create(env_.container().container())); + CEL_ASSIGN_OR_RETURN(auto generator, + NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 714e669cd..a82dd58ee 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -75,6 +75,7 @@ using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; namespace testpb3 = ::cel::expr::conformance::proto3; +namespace testpb2 = ::cel::expr::conformance::proto2; std::string SevString(Severity severity) { switch (severity) { @@ -583,6 +584,34 @@ TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } +TEST(TypeCheckerImplTest, NamespacedFunctionWithAbbreviation) { + TypeCheckEnv 
env(GetSharedTestingDescriptorPool()); + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + env.set_container(*MakeExpressionContainer("", "x.y.foo")); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); +} + TEST(TypeCheckerImplTest, MixedListTypeToDyn) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); @@ -1689,10 +1718,257 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType(PrimitiveType::kBool), })); +TEST(AliasTest, ImportVariable) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr.conformance", + "com.example.TestVariable1", + "com.example.TestVariable2")); + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable1", + MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("com.example.TestVariable2", + MessageType(testpb2::TestAllTypes::descriptor())))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + 
TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + "TestVariable1.single_int64 == TestVariable2.single_int64")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + ASSERT_TRUE(checked_ast->root_expr().has_call_expr()); + ASSERT_EQ(checked_ast->root_expr().call_expr().function(), "_==_"); + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[0] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable1"); + ASSERT_EQ(checked_ast->root_expr() + .call_expr() + .args()[1] + .select_expr() + .operand() + .ident_expr() + .name(), + "com.example.TestVariable2"); +} + +TEST(AliasTest, AliasToContainerResolvesMessage) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT( + checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes")))))); + + EXPECT_THAT( + checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), + Property(&Reference::name, + "cel.expr.conformance.proto3.TestAllTypes")))); + + 
EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(AliasTest, AliasSimpleName) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("foo", "bar"), IsOk()); + + env.set_container(std::move(container)); + + google::protobuf::LinkMessageReflection(); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + env.InsertOrReplaceVariable(MakeVariableDecl("bar", IntType())); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), "bar"); +} + +TEST(AliasTest, AliasPreventsContainerResolution) { + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_OK_AND_ASSIGN(ExpressionContainer container, + MakeExpressionContainer("cel.expr")); + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + + ASSERT_TRUE(env.InsertVariableIfAbsent( + MakeVariableDecl("cel.expr.pb3.FooVariable", IntType()))); + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + 
EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.FooVariable'"))); + } + + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("expr.pb3.FooVariable")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().ident_expr().name(), + "cel.expr.pb3.FooVariable"); + } +} + +TEST(AliasTest, AliasPreventsDisambiguation) { + // Copying behavior from cel-go and cel-java. + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ExpressionContainer container; + ASSERT_THAT(container.AddAlias("pb3", "cel.expr.conformance.proto3"), IsOk()); + env.set_container(std::move(container)); + env.InsertOrReplaceVariable(MakeVariableDecl("pb3.Foo", IntType())); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst(".pb3.TestAllTypes{single_int64: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->root_expr().struct_expr().name(), + "cel.expr.conformance.proto3.TestAllTypes"); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("pb3.Foo")); + 
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to 'pb3.Foo'"))); + } + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(".pb3.Foo")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring( + Severity::kError, "undeclared reference to '.pb3.Foo'"))); + } +} + class GenericMessagesTest : public testing::TestWithParam { }; -TEST_P(GenericMessagesTest, TypeChecksProto3) { +TEST_P(GenericMessagesTest, TypeChecksProto3Imports) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_container(*MakeExpressionContainer( + "", "cel.expr.conformance.proto3.TestAllTypes", + "cel.expr.conformance.proto3.NestedTestAllTypes")); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) << result.FormatError(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(test_case.expected_result_type)))) + << cel::test::FormatBaselineAst(*checked_ast); +} + +TEST_P(GenericMessagesTest, TypeChecksProto3Container) { const 
CheckedExprTestCase& test_case = GetParam(); google::protobuf::Arena arena; @@ -1715,11 +1991,7 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { return; } - ASSERT_TRUE(result.IsValid()) - << absl::StrJoin(result.GetIssues(), "\n", - [](std::string* out, const TypeCheckIssue& issue) { - absl::StrAppend(out, issue.message()); - }); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); @@ -1840,6 +2112,12 @@ INSTANTIATE_TEST_SUITE_P( .expected_result_type = AstType( MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_nested_message: " + "[TestAllTypes.NestedMessage{bb: 42}]}", + .expected_result_type = AstType( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + }, CheckedExprTestCase{ .expr = "TestAllTypes{single_duration: duration('1s')}", .expected_result_type = AstType( diff --git a/common/BUILD b/common/BUILD index ea6246b51..db177951d 100644 --- a/common/BUILD +++ b/common/BUILD @@ -1146,6 +1146,7 @@ cc_library( srcs = ["container.cc"], hdrs = ["container.h"], deps = [ + "//internal:lexis", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/container.cc b/common/container.cc index dbfa987d0..f69f0cc80 100644 --- a/common/container.cc +++ b/common/container.cc @@ -14,7 +14,6 @@ #include "common/container.h" -#include #include #include @@ -22,48 +21,28 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "internal/lexis.h" namespace cel { namespace { -// Basic validation for accidental misuse. Does not fully validate against the -// CEL grammar rules for identifiers. 
-bool IsIdentifierChar(char c) { - return c == '_' || std::isalnum(static_cast(c)); -} - bool IsValidQualifiedName(absl::string_view name) { - bool dot_ok = false; - for (char c : name) { - if (c == '.') { - if (!dot_ok) { - return false; - } - dot_ok = false; - continue; - } - if (!IsIdentifierChar(c)) { + auto dot_pos = name.find('.'); + while (dot_pos != absl::string_view::npos) { + if (!internal::LexisIsIdentifier(name.substr(0, dot_pos))) { return false; } - dot_ok = true; + name = name.substr(dot_pos + 1); + dot_pos = name.find('.'); } - // Must not end in a dot. - return dot_ok; + return internal::LexisIsIdentifier(name); } bool IsValidAlias(absl::string_view alias) { - if (alias.empty()) { - return false; - } - for (char c : alias) { - if (!IsIdentifierChar(c)) { - return false; - } - } - return true; + return internal::LexisIsIdentifier(alias); } -bool IsAbreviation(absl::string_view alias, absl::string_view name) { +bool IsAbbreviationImpl(absl::string_view alias, absl::string_view name) { auto pos = name.rfind('.'); return pos != std::string::npos && pos > 0 && pos < name.size() - 1 && alias == name.substr(pos + 1); @@ -72,7 +51,7 @@ bool IsAbreviation(absl::string_view alias, absl::string_view name) { } // namespace bool ExpressionContainer::AliasListing::IsAbbreviation() const { - return IsAbreviation(alias, name); + return IsAbbreviationImpl(alias, name); } absl::StatusOr MakeExpressionContainer( @@ -170,7 +149,7 @@ absl::string_view ExpressionContainer::FindAlias( std::vector ExpressionContainer::ListAbbreviations() const { std::vector res; for (const auto& entry : aliases_) { - if (IsAbreviation(entry.first, entry.second)) { + if (IsAbbreviationImpl(entry.first, entry.second)) { res.push_back(entry.second); } } diff --git a/common/container.h b/common/container.h index cd40aaef9..ad8d91c35 100644 --- a/common/container.h +++ b/common/container.h @@ -33,8 +33,9 @@ namespace cel { // approximately the same resolution rules as protobuf or C++ 
namespaces. // // Aliases declare short names that can be referenced without resolving against -// the scopes defined by the container. For consistency, an alias cannot be -// a prefix of the container name. Aliases are always unqualified identifiers. +// the scopes defined by the container. An alias cannot be a prefix of the +// container name, (otherwise re-type-checking an expression could +// change the meaning). Aliases are always unqualified identifiers. // // An abbreviation is a special case of alias that behaves like an import or // using declaration in other languages. (pkg.TypeName -> TypeName). diff --git a/common/container_test.cc b/common/container_test.cc index 991362320..e40814f54 100644 --- a/common/container_test.cc +++ b/common/container_test.cc @@ -60,6 +60,8 @@ TEST(ExpressionContainerTest, SetContainer) { EXPECT_THAT(container.container(), Eq("my.container.name")); EXPECT_THAT(container.SetContainer("..invalid"), StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(container.SetContainer("foo.1invalid"), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(ExpressionContainerTest, AddAlias) { From 1ca513feb24061c792b1ed086105df39f4339f5c Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Thu, 23 Apr 2026 23:23:06 -0700 Subject: [PATCH 50/88] Add signature generation functions for types and function overloads PiperOrigin-RevId: 904826666 --- common/BUILD | 1 + common/decl.cc | 187 ++-------------------- common/decl_test.cc | 28 +--- common/internal/BUILD | 31 ++++ common/internal/signature.cc | 211 +++++++++++++++++++++++++ common/internal/signature.h | 61 ++++++++ common/internal/signature_test.cc | 249 ++++++++++++++++++++++++++++++ env/type_info.cc | 6 + env/type_info_test.cc | 4 + 9 files changed, 580 insertions(+), 198 deletions(-) create mode 100644 common/internal/signature.cc create mode 100644 common/internal/signature.h create mode 100644 common/internal/signature_test.cc diff --git a/common/BUILD b/common/BUILD index 
db177951d..0ead8b15a 100644 --- a/common/BUILD +++ b/common/BUILD @@ -114,6 +114,7 @@ cc_library( ":constant", ":type", ":type_kind", + "//common/internal:signature", "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", diff --git a/common/decl.cc b/common/decl.cc index 1e06cb703..b338bfd4f 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -23,8 +23,10 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/internal/signature.h" #include "common/type.h" #include "common/type_kind.h" @@ -104,181 +106,6 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { return args_overlap; } -void AppendEscaped(std::string* result, absl::string_view str, - bool escape_dot) { - for (char c : str) { - switch (c) { - case '\\': - case '(': - case ')': - case '<': - case '>': - case '"': - case ',': - result->push_back('\\'); - result->push_back(c); - break; - case '.': - if (escape_dot) { - result->push_back('\\'); - } - result->push_back(c); - break; - default: - result->push_back(c); - break; - } - } -} - -void AppendTypeParameters(std::string* result, const Type& type); - -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. 
-void AppendTypeToOverloadId(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - return; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - return; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - return; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - return; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - return; - case TypeKind::kString: - absl::StrAppend(result, "string"); - return; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - return; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - return; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - return; - case TypeKind::kUnknown: - absl::StrAppend(result, "unknown"); - return; - case TypeKind::kError: - absl::StrAppend(result, "error"); - return; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - return; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - return; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - return; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - return; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - return; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - return; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - return; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - return; - case TypeKind::kList: - absl::StrAppend(result, "list"); - AppendTypeParameters(result, type); - return; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - AppendTypeParameters(result, type); - return; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - AppendTypeParameters(result, type); - return; - case TypeKind::kEnum: - absl::StrAppend(result, "enum"); - AppendTypeParameters(result, type); - return; - case TypeKind::kType: - 
absl::StrAppend(result, "type"); - AppendTypeParameters(result, type); - return; - case TypeKind::kOpaque: - result->push_back('"'); - AppendEscaped(result, type.name(), /*escape_dot=*/false); - result->push_back('"'); - AppendTypeParameters(result, type); - return; - default: // This includes TypeKind::kStruct aka TypeKind::kTypeMessage - AppendEscaped(result, type.name(), /*escape_dot=*/false); - return; - } -} - -void AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - AppendTypeToOverloadId(result, parameters[i]); - if (i < parameters.size() - 1) { - result->push_back(','); - } - } - result->push_back('>'); - } -} - -// Generates an identifier for the overload based on the function name and -// the types of the arguments. If `member` is true, the first argument type -// is used as the receiver and is prepended to the function name, followed by -// a dot. -// -// Examples: -// -// - `foo()` -// - `foo(int)` -// - `bar.foo(int)` -// - `foo(int,string)` -// - `foo(list,list)` -// - `bar.foo(list,list<"my_type">)` -// -std::string GenerateOverloadId(std::string_view function_name, - const std::vector& args, bool member) { - std::string result; - if (member) { - if (!args.empty()) { - AppendTypeToOverloadId(&result, args[0]); - } else { - // This should never happen: a member function with no receiver. - absl::StrAppend(&result, "error"); - } - result.push_back('.'); - } - AppendEscaped(&result, function_name, /*escape_dot=*/true); - result.push_back('('); - for (size_t i = member ? 
1 : 0; i < args.size(); ++i) { - AppendTypeToOverloadId(&result, args[i]); - if (i < args.size() - 1) { - result.push_back(','); - } - } - result.push_back(')'); - - return result; -} - template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, @@ -290,8 +117,14 @@ void AddOverloadInternal(std::string_view function_name, if (overload.id().empty()) { OverloadDecl overload_decl = overload; - overload_decl.set_id(GenerateOverloadId(function_name, overload_decl.args(), - overload_decl.member())); + absl::StatusOr overload_id = + common_internal::MakeOverloadSignature( + function_name, overload_decl.args(), overload_decl.member()); + if (!overload_id.ok()) { + status = overload_id.status(); + return; + } + overload_decl.set_id(*overload_id); AddOverloadInternal(function_name, insertion_order, overloads, std::move(overload_decl), status); return; diff --git a/common/decl_test.cc b/common/decl_test.cc index 6e5710049..510cd5017 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -14,7 +14,10 @@ #include "common/decl.h" -#include "absl/log/die_if_null.h" +#include +#include + +#include "absl/log/die_if_null.h" // IWYU pragma: keep #include "absl/status/status.h" #include "common/constant.h" #include "common/type.h" @@ -186,7 +189,6 @@ TEST(FunctionDecl, OverloadId) { MakeOverloadDecl(IntType{}, TimestampType{}), MakeOverloadDecl(IntType{}, IntWrapperType{}), MakeOverloadDecl(IntType{}, MessageType(descriptor)), - MakeMemberOverloadDecl(IntType{}), MakeMemberOverloadDecl(StringType{}, StringType{}), MakeMemberOverloadDecl(StringType{}, StringType{}, ListType(&arena, BoolType{})), @@ -198,36 +200,20 @@ TEST(FunctionDecl, OverloadId) { ElementsAre(Property(&OverloadDecl::id, "hello()"), Property(&OverloadDecl::id, "hello(string)"), Property(&OverloadDecl::id, "hello(int,uint)"), - Property(&OverloadDecl::id, "hello(list)"), - Property(&OverloadDecl::id, "hello(map)"), - Property(&OverloadDecl::id, "hello(\"bar\">)"), + 
Property(&OverloadDecl::id, "hello(list<~A>)"), + Property(&OverloadDecl::id, "hello(map<~B,~C>)"), + Property(&OverloadDecl::id, "hello(bar>)"), Property(&OverloadDecl::id, "hello(any)"), Property(&OverloadDecl::id, "hello(duration)"), Property(&OverloadDecl::id, "hello(timestamp)"), Property(&OverloadDecl::id, "hello(int_wrapper)"), Property(&OverloadDecl::id, "hello(cel.expr.conformance.proto3.TestAllTypes)"), - Property(&OverloadDecl::id, "error.hello()"), Property(&OverloadDecl::id, "string.hello()"), Property(&OverloadDecl::id, "string.hello(list)"), Property(&OverloadDecl::id, "string.hello(bool,dyn)"))); } -TEST(FunctionDecl, OverloadIdEscaping) { - google::protobuf::Arena arena; - ASSERT_OK_AND_ASSIGN( - auto function_decl, - MakeFunctionDecl("h.(e),l\\o", - MakeMemberOverloadDecl( - StringType{}, StringType{}, - ListType(&arena, TypeParamType("a,b..(d)\\e"))))); - - EXPECT_THAT(function_decl.overloads(), - ElementsAre(Property(&OverloadDecl::id, - "string.h\\.\\(e\\)\\,l\\\\\\o(list<" - "a\\,b.\\.\\(d\\)\\\\e>)"))); -} - using common_internal::TypeIsAssignable; TEST(TypeIsAssignable, BoolWrapper) { diff --git a/common/internal/BUILD b/common/internal/BUILD index c5ca63564..10084b685 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -137,3 +137,34 @@ cc_library( "@com_google_protobuf//src/google/protobuf/io", ], ) + +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + "//common:type", + "//common:type_kind", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":signature", + "//common:type", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + 
"@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/internal/signature.cc b/common/internal/signature.cc new file mode 100644 index 000000000..f63049878 --- /dev/null +++ b/common/internal/signature.cc @@ -0,0 +1,211 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/signature.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/status_macros.h" + +namespace cel::common_internal { + +namespace { + +void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { + for (char c : str) { + switch (c) { + case '\\': + case '(': + case ')': + case '<': + case '>': + case '"': + case ',': + case '~': + result->push_back('\\'); + break; + case '.': + if (escape_dot) { + result->push_back('\\'); + } + break; + } + result->push_back(c); + } +} + +absl::Status AppendTypeParameters(std::string* result, const Type& type); + +// Recursively appends a string representation of the given `type` to `result`. +// Type parameters are enclosed in angle brackets and separated by commas. + +// Grammar: +// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; +// NamespaceIdentifier = [ "." 
] Identifier { "." Identifier } ; +// TypeList = TypeElem { "," TypeElem } ; +// TypeElem = TypeDesc | TypeParam +// TypeParam = "~" Alpha ; +// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; +// (* Terminals *) +// Alpha = "a"..."z" | "A"..."Z" ; +// Digit = "0"..."9" ; +// AlphaNumeric = Alpha | Digit ; +// +// For compatibility, the implementation allows unexpected characters in +// type names and parameters and escapes them with a backslash. +absl::Status AppendTypeDesc(std::string* result, const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + absl::StrAppend(result, "null"); + break; + case TypeKind::kBool: + absl::StrAppend(result, "bool"); + break; + case TypeKind::kInt: + absl::StrAppend(result, "int"); + break; + case TypeKind::kUint: + absl::StrAppend(result, "uint"); + break; + case TypeKind::kDouble: + absl::StrAppend(result, "double"); + break; + case TypeKind::kString: + absl::StrAppend(result, "string"); + break; + case TypeKind::kBytes: + absl::StrAppend(result, "bytes"); + break; + case TypeKind::kDuration: + absl::StrAppend(result, "duration"); + break; + case TypeKind::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case TypeKind::kAny: + absl::StrAppend(result, "any"); + break; + case TypeKind::kDyn: + absl::StrAppend(result, "dyn"); + break; + case TypeKind::kBoolWrapper: + absl::StrAppend(result, "bool_wrapper"); + break; + case TypeKind::kIntWrapper: + absl::StrAppend(result, "int_wrapper"); + break; + case TypeKind::kUintWrapper: + absl::StrAppend(result, "uint_wrapper"); + break; + case TypeKind::kDoubleWrapper: + absl::StrAppend(result, "double_wrapper"); + break; + case TypeKind::kStringWrapper: + absl::StrAppend(result, "string_wrapper"); + break; + case TypeKind::kBytesWrapper: + absl::StrAppend(result, "bytes_wrapper"); + break; + case TypeKind::kList: + absl::StrAppend(result, "list"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kMap: + 
absl::StrAppend(result, "map"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kFunction: + absl::StrAppend(result, "function"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kType: + absl::StrAppend(result, "type"); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kTypeParam: + absl::StrAppend(result, "~"); + AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); + break; + case TypeKind::kOpaque: + AppendEscaped(result, type.name(), /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + case TypeKind::kStruct: + AppendEscaped(result, type.name(), /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); + break; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Type kind: %s is not supported in CEL declarations", + type.DebugString())); + } + return absl::OkStatus(); +} + +absl::Status AppendTypeParameters(std::string* result, const Type& type) { + const auto& parameters = type.GetParameters(); + if (!parameters.empty()) { + result->push_back('<'); + for (size_t i = 0; i < parameters.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); + if (i < parameters.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); + } + return absl::OkStatus(); +} +} // namespace + +absl::StatusOr MakeTypeSignature(const Type& type) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + return result; +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { + std::string result; + if (is_member) { + if (!args.empty()) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[0])); + } else { + return absl::InvalidArgumentError("Member function with no receiver"); + } + result.push_back('.'); + } + AppendEscaped(&result, function_name, 
/*escape_dot=*/true); + result.push_back('('); + for (size_t i = is_member ? 1 : 0; i < args.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, args[i])); + if (i < args.size() - 1) { + result.push_back(','); + } + } + result.push_back(')'); + + return result; +} +} // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h new file mode 100644 index 000000000..3f31d8fd1 --- /dev/null +++ b/common/internal/signature.h @@ -0,0 +1,61 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "common/type.h" + +namespace cel::common_internal { + +// Generates a signature for a `cel::Type`, which is a string representation of +// the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSignature(const Type& type); + +// Generates an identifier for a function overload based on the function name +// and the types of the arguments. If `is_member` is true, the first argument +// type is used as the receiver and is prepended to the function name, followed +// by a dot.
+// +// Examples: +// +// - `foo()` +// - `foo(int)` +// - `bar.foo(int)` +// - `foo(int,string)` +// - `foo(list,list)` +// - `bar.foo(list,list>)` +// +// If the function name contains a period, it is escaped with a backslash, e.g. +// `foo.bar` becomes `foo\.bar`. This disambiguates between a member function +// and a qualified target type name. +// +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc new file mode 100644 index 000000000..8e41c70fb --- /dev/null +++ b/common/internal/signature_test.cc @@ -0,0 +1,249 @@ +#include "common/internal/signature.h" +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "common/type.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::ValuesIn; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +struct TypeSignatureTestCase { + Type type; + std::string expected_signature; + std::string expected_error; +}; + +using TypeSignatureTest = testing::TestWithParam; + +TEST_P(TypeSignatureTest, TypeSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + common_internal::MakeTypeSignature(param.type); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetTypeSignatureTestCases() { + return { + { + .type = StringType{}, + .expected_signature = "string", + }, + { + .type = IntType{}, + .expected_signature = "int", + }, + { + .type = ListType(GetTestArena(), StringType{}), + .expected_signature = "list", + }, + { + .type = ListType(GetTestArena(), TypeParamType("A")), + .expected_signature = "list<~A>", + }, + { + .type = MapType(GetTestArena(), IntType{}, DynType{}), + .expected_signature = "map", + }, + { + .type = + MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C")), + .expected_signature = "map<~B,~C>", + }, + { + .type = OpaqueType( + GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), {})}), + .expected_signature = "bar>", + }, + { + .type = AnyType{}, + .expected_signature = "any", + }, + { + .type = DurationType{}, + 
.expected_signature = "duration", + }, + { + .type = TimestampType{}, + .expected_signature = "timestamp", + }, + { + .type = IntWrapperType{}, + .expected_signature = "int_wrapper", + }, + { + .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = UnknownType{}, + .expected_error = + "Type kind: *unknown* is not supported in CEL declarations", + }, + { + .type = ErrorType{}, + .expected_error = + "Type kind: *error* is not supported in CEL declarations", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + +struct OverloadSignatureTestCase { + std::string function_name = "hello"; + std::vector args; + bool is_member = false; + std::string expected_signature; + std::string expected_error; +}; + +using OverloadSignatureTest = testing::TestWithParam; + +TEST_P(OverloadSignatureTest, OverloadSignature) { + const auto& param = GetParam(); + + absl::StatusOr signature = + common_internal::MakeOverloadSignature(param.function_name, param.args, + param.is_member); + if (!param.expected_error.empty()) { + EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr(param.expected_error))); + } else { + EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + } +} + +std::vector GetOverloadSignatureTestCases() { + return { + { + .args = {StringType{}}, + .expected_signature = "hello(string)", + }, + { + .args = {IntType{}, UintType{}}, + .expected_signature = "hello(int,uint)", + }, + { + .args = {ListType(GetTestArena(), StringType{})}, + .expected_signature = "hello(list)", + }, + { + .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .expected_signature = "hello(list<~A>)", + }, + { + .args = 
{MapType(GetTestArena(), IntType{}, DynType{})}, + .expected_signature = "hello(map)", + }, + { + .args = {MapType(GetTestArena(), TypeParamType("B"), + TypeParamType("C"))}, + .expected_signature = "hello(map<~B,~C>)", + }, + { + .args = {OpaqueType( + GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .expected_signature = "hello(bar>)", + }, + { + .args = {AnyType{}}, + .expected_signature = "hello(any)", + }, + { + .args = {DurationType{}}, + .expected_signature = "hello(duration)", + }, + { + .args = {TimestampType{}}, + .expected_signature = "hello(timestamp)", + }, + { + .args = {IntWrapperType{}}, + .expected_signature = "hello(int_wrapper)", + }, + { + .args = {MessageType( + GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"))}, + .expected_signature = + "hello(cel.expr.conformance.proto3.TestAllTypes)", + }, + {.args = {}, + .is_member = true, + .expected_error = "Member function with no receiver"}, + { + .args = {StringType{}}, + .is_member = true, + .expected_signature = "string.hello()", + }, + { + .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .is_member = true, + .expected_signature = "string.hello(list)", + }, + { + .args = {StringType{}, BoolType{}, DynType{}}, + .is_member = true, + .expected_signature = "string.hello(bool,dyn)", + }, + { + .function_name = R"(h.(e),l\o)", + .args = {StringType{}, + ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .is_member = true, + .expected_signature = + R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, + ValuesIn(GetOverloadSignatureTestCases())); + +} // namespace +} // namespace cel::common_internal diff --git a/env/type_info.cc b/env/type_info.cc index ed72a842f..a5b47b6f1 100644 --- a/env/type_info.cc +++ b/env/type_info.cc @@ -59,11 +59,17 @@ std::optional TypeNameToTypeKind(absl::string_view type_name) { {"any", 
TypeKind::kAny}, {"dyn", TypeKind::kDyn}, {BoolWrapperType::kName, TypeKind::kBoolWrapper}, + {"bool_wrapper", TypeKind::kBoolWrapper}, {IntWrapperType::kName, TypeKind::kIntWrapper}, + {"int_wrapper", TypeKind::kIntWrapper}, {UintWrapperType::kName, TypeKind::kUintWrapper}, + {"uint_wrapper", TypeKind::kUintWrapper}, {DoubleWrapperType::kName, TypeKind::kDoubleWrapper}, + {"double_wrapper", TypeKind::kDoubleWrapper}, {StringWrapperType::kName, TypeKind::kStringWrapper}, + {"string_wrapper", TypeKind::kStringWrapper}, {BytesWrapperType::kName, TypeKind::kBytesWrapper}, + {"bytes_wrapper", TypeKind::kBytesWrapper}, {"type", TypeKind::kType}, }); if (auto it = kTypeNameToTypeKind->find(type_name); diff --git a/env/type_info_test.cc b/env/type_info_test.cc index ca9d0467c..015d8a928 100644 --- a/env/type_info_test.cc +++ b/env/type_info_test.cc @@ -105,6 +105,10 @@ std::vector GetTestCases() { .type_info = {.name = "google.protobuf.DoubleValue"}, .expected_type_pb = "wrapper: DOUBLE", }, + TestCase{ + .type_info = {.name = "double_wrapper"}, + .expected_type_pb = "wrapper: DOUBLE", + }, TestCase{ .type_info = {.name = "type", .params = {Config::TypeInfo{.name = "duration"}}}, From c4a927e63bc4cb9f3ad4b61e37f4456c2d883380 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 24 Apr 2026 14:24:47 -0700 Subject: [PATCH 51/88] Add ToBuilder() functions for parser/checker/compiler Adds support for extending a configured compiler similar to cel-java. 
PiperOrigin-RevId: 905220235 --- checker/BUILD | 1 + checker/internal/BUILD | 2 + checker/internal/type_check_env.h | 17 ++++- checker/internal/type_checker_builder_impl.cc | 30 +++++--- checker/internal/type_checker_builder_impl.h | 31 +++++--- checker/internal/type_checker_impl.cc | 6 ++ checker/internal/type_checker_impl.h | 3 + checker/internal/type_checker_impl_test.cc | 39 ++++++++++ checker/type_checker.h | 5 ++ checker/type_checker_builder.h | 1 - checker/type_checker_builder_factory_test.cc | 59 +++++++++++++++ compiler/BUILD | 1 + compiler/compiler.h | 13 +++- compiler/compiler_factory.cc | 14 +++- compiler/compiler_factory_test.cc | 26 +++++++ parser/macro_registry.cc | 10 +++ parser/macro_registry.h | 4 + parser/parser.cc | 27 ++++++- parser/parser_interface.h | 3 + parser/parser_test.cc | 75 +++++++++++++++++++ 20 files changed, 335 insertions(+), 32 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index 7b151d6a8..f1e0cef3c 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -123,6 +123,7 @@ cc_test( srcs = ["type_checker_builder_factory_test.cc"], deps = [ ":checker_options", + ":optional", ":standard_library", ":type_checker", ":type_checker_builder", diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1af48af57..1c560cdb9 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -74,6 +74,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -175,6 +176,7 @@ cc_test( ":type_checker_impl", "//checker:checker_options", "//checker:type_check_issue", + "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", "//common:container", diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 491e4b550..15f8ecc4d 100644 --- 
a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -23,6 +23,7 @@ #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -100,7 +101,8 @@ class TypeCheckEnv { type_providers_.push_back(proto_type_introspector_); } - // Move-only. + TypeCheckEnv(const TypeCheckEnv&) = default; + TypeCheckEnv& operator=(const TypeCheckEnv&) = default; TypeCheckEnv(TypeCheckEnv&&) = default; TypeCheckEnv& operator=(TypeCheckEnv&&) = default; @@ -193,11 +195,15 @@ class TypeCheckEnv { // Used to keep an arena alive if one was needed to allocate types. // - // The TypeCheckEnv does not otherwise use it. - void set_arena(std::shared_ptr arena) { + // Expected to be called exactly once if at all. + void set_arena(std::shared_ptr arena) { + ABSL_DCHECK(arena_ == nullptr || arena == arena_); arena_ = std::move(arena); } + // Returns the arena if one was set, nullptr otherwise. + std::shared_ptr arena() const { return arena_; } + private: absl::StatusOr> LookupEnumConstant( absl::string_view type, absl::string_view value) const; @@ -205,7 +211,10 @@ class TypeCheckEnv { absl_nonnull std::shared_ptr descriptor_pool_; // If set, an arena was needed to allocate types in the environment. - absl_nullable std::shared_ptr arena_; + // + // The TypeCheckEnv does not otherwise use the arena, though it may be used by + // derived TypeCheckerBuilders. + absl_nullable std::shared_ptr arena_; ExpressionContainer container_; // Used to resolve fields on message types. 
diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 9ebcb4e34..94a05602e 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -34,6 +34,7 @@ #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" @@ -342,8 +343,16 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } absl::StatusOr> TypeCheckerBuilderImpl::Build() { - TypeCheckEnv env(descriptor_pool_); - env.set_container(expression_container_); + TypeCheckEnv env(template_env_); + CEL_RETURN_IF_ERROR(ConfigureTypeCheckEnv(env)); + return std::make_unique(std::move(env), + options_); +} + +absl::Status TypeCheckerBuilderImpl::ConfigureTypeCheckEnv(TypeCheckEnv& env) { + if (expression_container_.has_value()) { + env.set_container(*expression_container_); + } if (expected_type_.has_value()) { env.set_expected_type(*expected_type_); } @@ -377,12 +386,10 @@ absl::StatusOr> TypeCheckerBuilderImpl::Build() { /*subset=*/nullptr, env)); CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, /*subset=*/nullptr, env)); - // A library may have been the first to initialize the arena, so we need to - // set it as the last step. 
- env.set_arena(arena_); - auto checker = std::make_unique( - std::move(env), options_); - return checker; + if (type_arena_ != nullptr) { + env.set_arena(type_arena_); + } + return absl::OkStatus(); } absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) { @@ -432,7 +439,7 @@ absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable( absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( absl::string_view type) { const google::protobuf::Descriptor* desc = - descriptor_pool_->FindMessageTypeByName(type); + template_env_.descriptor_pool()->FindMessageTypeByName(type); if (desc == nullptr) { return absl::NotFoundError( absl::StrCat("context declaration '", type, "' not found")); @@ -479,7 +486,10 @@ void TypeCheckerBuilderImpl::AddTypeProvider( } void TypeCheckerBuilderImpl::set_container(absl::string_view container) { - expression_container_.SetContainer(container).IgnoreError(); + if (!expression_container_.has_value()) { + expression_container_.emplace(); + } + expression_container_->SetContainer(container).IgnoreError(); } void TypeCheckerBuilderImpl::SetExpressionContainer( diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 7a099040b..646a5d16f 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -40,8 +40,6 @@ namespace cel::checker_internal { -class TypeCheckerBuilderImpl; - // Builder for TypeChecker instances. class TypeCheckerBuilderImpl : public TypeCheckerBuilder { public: @@ -51,7 +49,18 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options) : options_(options), target_config_(&default_config_), - descriptor_pool_(std::move(descriptor_pool)) {} + template_env_(std::move(descriptor_pool)) {} + + // Constructor for building an extended TypeChecker. 
+ explicit TypeCheckerBuilderImpl(const CheckerOptions& options, + const TypeCheckEnv& template_env) + : options_(options), + target_config_(&default_config_), + template_env_(template_env) { + if (auto arena = template_env_.arena(); arena != nullptr) { + type_arena_ = std::move(arena); + } + } // Move only. TypeCheckerBuilderImpl(const TypeCheckerBuilderImpl&) = delete; @@ -83,14 +92,14 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { const CheckerOptions& options() const override { return options_; } google::protobuf::Arena* absl_nonnull arena() override { - if (arena_ == nullptr) { - arena_ = std::make_shared(); + if (type_arena_ == nullptr) { + type_arena_ = std::make_shared(); } - return arena_.get(); + return type_arena_.get(); } const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool() const override { - return descriptor_pool_.get(); + return template_env_.descriptor_pool(); } private: @@ -129,6 +138,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status ApplyConfig(ConfigRecord config, const TypeCheckerSubset* subset, TypeCheckEnv& env); + absl::Status ConfigureTypeCheckEnv(TypeCheckEnv& env); + CheckerOptions options_; // Default target for configuration changes. Used for direct calls to // AddVariable, AddFunction, etc. @@ -136,12 +147,12 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { // Active target for configuration changes. // This is used to track which library the change is made on behalf of. 
ConfigRecord* absl_nonnull target_config_; - std::shared_ptr descriptor_pool_; - std::shared_ptr arena_; + TypeCheckEnv template_env_; + std::shared_ptr type_arena_; std::vector libraries_; absl::flat_hash_map subsets_; absl::flat_hash_set library_ids_; - ExpressionContainer expression_container_; + absl::optional expression_container_; absl::optional expected_type_; }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6697cc06a..8f67efbde 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -37,8 +37,10 @@ #include "checker/internal/format_type_name.h" #include "checker/internal/namespace_generator.h" #include "checker/internal/type_check_env.h" +#include "checker/internal/type_checker_builder_impl.h" #include "checker/internal/type_inference_context.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/ast_rewrite.h" @@ -1326,4 +1328,8 @@ absl::StatusOr TypeCheckerImpl::Check( return ValidationResult(std::move(ast), std::move(issues)); } +std::unique_ptr TypeCheckerImpl::ToBuilder() const { + return std::make_unique(options_, env_); +} + } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 1b9062ec1..71683276d 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -22,6 +22,7 @@ #include "checker/checker_options.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "google/protobuf/arena.h" @@ -44,6 +45,8 @@ class TypeCheckerImpl : public TypeChecker { absl::StatusOr Check( std::unique_ptr ast) const override; + std::unique_ptr ToBuilder() const override; + private: TypeCheckEnv env_; 
google::protobuf::Arena type_arena_; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index a82dd58ee..e6cd641d6 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -33,6 +33,7 @@ #include "checker/internal/test_ast_helpers.h" #include "checker/internal/type_check_env.h" #include "checker/type_check_issue.h" +#include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" #include "common/container.h" @@ -1413,6 +1414,44 @@ TEST(TypeCheckerImplTest, ExpectedTypeDoesntMatch) { "expected type 'map(string, string)' but found 'map(string, int)'"))); } +TEST(TypeCheckerImplTest, ToBuilder) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + TypeCheckerImpl impl(std::move(env)); + auto builder = impl.ToBuilder(); + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN(auto new_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + new_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + +TEST(TypeCheckerImplTest, ToBuilderPropagatesArena) { + auto arena = std::make_shared(); + + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + env.set_arena(arena); + + Type list_type = ListType(arena.get(), IntType()); + ASSERT_TRUE( + env.InsertVariableIfAbsent(MakeVariableDecl("my_list", list_type))); + + auto base_checker = std::make_unique(std::move(env)); + + std::unique_ptr builder = base_checker->ToBuilder(); + + base_checker.reset(); + arena.reset(); + + ASSERT_OK_AND_ASSIGN(auto derived_checker, builder->Build()); + + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("my_list")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + derived_checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerImplTest, BadSourcePosition) { TypeCheckEnv 
env(GetSharedTestingDescriptorPool()); diff --git a/checker/type_checker.h b/checker/type_checker.h index 993eafb71..e47b7dca6 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -23,6 +23,8 @@ namespace cel { +class TypeCheckerBuilder; + // TypeChecker interface. // // Checks references and type agreement for a parsed CEL expression. @@ -43,6 +45,9 @@ class TypeChecker { virtual absl::StatusOr Check( std::unique_ptr ast) const = 0; + // Returns a builder initialized with the configuration of this type checker. + virtual std::unique_ptr ToBuilder() const = 0; + // TODO(uncreated-issue/73): add overload for cref AST. }; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index b3a86f64c..5dd1f5256 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -35,7 +35,6 @@ namespace cel { class TypeCheckerBuilder; -class TypeCheckerBuilderImpl; // Functional implementation to apply the library features to a // TypeCheckerBuilder. 
diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 030186c83..38430de5f 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -23,6 +23,7 @@ #include "absl/strings/string_view.h" #include "checker/checker_options.h" #include "checker/internal/test_ast_helpers.h" +#include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -743,5 +744,63 @@ TEST(TypeCheckerBuilderTest, AddFunctionNoOverlapWithStdMacroError) { EXPECT_THAT(builder->AddFunction(fn_decl), IsOk()); } +TEST(TypeCheckerBuilderTest, ToBuilderIndependenceAndInheritance) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddVariable(MakeVariableDecl("x", IntType())), IsOk()); + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("addOne", + MakeOverloadDecl("addOne_int", IntType(), IntType()))); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker1, builder->Build()); + + // Exercise checker1. + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("addOne(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result1, + checker1->Check(std::move(ast))); + EXPECT_TRUE(result1.IsValid()); + } + + // Start new builder via ToBuilder. 
+ auto builder2 = checker1->ToBuilder(); + ASSERT_THAT(builder2->AddVariable(MakeVariableDecl("y", IntType())), IsOk()); + ASSERT_THAT(builder2->AddLibrary(OptionalCheckerLibrary()), IsOk()); + builder2->SetExpectedType(IntType()); + + ASSERT_OK_AND_ASSIGN(auto checker2, builder2->Build()); + + { + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("optional.of(addOne(x)).orValue(0) + y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result2, + checker2->Check(std::move(ast))); + EXPECT_TRUE(result2.IsValid()); + } + + // Demonstrate checker1 is unmodified and independent (still does not know + // about y). + { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("y")); + ASSERT_OK_AND_ASSIGN(ValidationResult result_y_checker1_again, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result_y_checker1_again.IsValid()); + } + + // Same for optional library functions. + { + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("optional.none().orValue(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker1->Check(std::move(ast))); + EXPECT_FALSE(result.IsValid()); + } +} + } // namespace } // namespace cel diff --git a/compiler/BUILD b/compiler/BUILD index 44ef4f537..170f1068b 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -66,6 +66,7 @@ cc_test( deps = [ ":compiler", ":compiler_factory", + ":optional", ":standard_library", "//checker:optional", "//checker:standard_library", diff --git a/compiler/compiler.h b/compiler/compiler.h index 6178cf2dc..48fa4e0b1 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -99,8 +99,8 @@ struct CompilerOptions { // Interface for CEL CompilerBuilder objects. // -// Builder implementations are thread hostile, but should create -// thread-compatible Compiler instances. +// Builder implementations do not provide any synchronization themselves, +// but create thread-compatible Compiler instances. 
class CompilerBuilder { public: virtual ~CompilerBuilder() = default; @@ -140,6 +140,15 @@ class Compiler { // Accessor for the underlying validator. virtual const Validator& GetValidator() const = 0; + + // Returns a builder initialized with the configuration of this compiler. + // + // The returned builder is a copy of the validated environment and may + // behave differently than the builder that created this compiler. + // + // The returned builder does not share state with the compiler and may be + // modified independently. + virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index c83633f68..3e9871706 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -65,6 +65,8 @@ class CompilerImpl : public Compiler { return result; } + std::unique_ptr ToBuilder() const override; + const TypeChecker& GetTypeChecker() const override { return *type_checker_; } const Parser& GetParser() const override { return *parser_; } const Validator& GetValidator() const override { return validator_; } @@ -78,9 +80,11 @@ class CompilerImpl : public Compiler { class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, - std::unique_ptr parser_builder) + std::unique_ptr parser_builder, + Validator validator = Validator()) : type_checker_builder_(std::move(type_checker_builder)), - parser_builder_(std::move(parser_builder)) {} + parser_builder_(std::move(parser_builder)), + validator_(std::move(validator)) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -154,6 +158,12 @@ class CompilerBuilderImpl : public CompilerBuilder { absl::flat_hash_set subsets_; }; +std::unique_ptr CompilerImpl::ToBuilder() const { + auto builder = std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_); + return builder; +} + } // namespace absl::StatusOr> 
NewCompilerBuilder( diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index cfdc68e26..d217e4cc7 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -29,6 +29,7 @@ #include "common/source.h" #include "common/type.h" #include "compiler/compiler.h" +#include "compiler/optional.h" #include "compiler/standard_library.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" @@ -364,5 +365,30 @@ TEST(CompilerFactoryTest, FailsIfNullDescriptorPool) { HasSubstr("descriptor_pool must not be null"))); } +TEST(CompilerFactoryTest, ToBuilderWorks) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + + ASSERT_THAT(builder->GetCheckerBuilder().AddVariable( + MakeVariableDecl("a", MapType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + auto derived_builder = compiler->ToBuilder(); + + ASSERT_THAT(derived_builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto derived_compiler, derived_builder->Build()); + + ASSERT_OK_AND_ASSIGN( + ValidationResult result, + derived_compiler->Compile("has(a.b) && a.?b.orValue('foo') == 'foo'")); + EXPECT_TRUE(result.IsValid()); +} + } // namespace } // namespace cel diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc index 3fc77f18c..3a816b10e 100644 --- a/parser/macro_registry.cc +++ b/parser/macro_registry.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/status/status.h" #include "absl/strings/match.h" @@ -70,6 +71,15 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, return absl::nullopt; } +std::vector MacroRegistry::ListMacros() const { + std::vector macros; + macros.reserve(macros_.size()); + for (auto it = macros_.begin(); it != macros_.end(); ++it) { + macros.push_back(it->second); + } + return macros; +} + 
bool MacroRegistry::RegisterMacroImpl(const Macro& macro) { return macros_.insert(std::pair{macro.key(), macro}).second; } diff --git a/parser/macro_registry.h b/parser/macro_registry.h index 51899bade..01a0634ef 100644 --- a/parser/macro_registry.h +++ b/parser/macro_registry.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_PARSER_MACRO_REGISTRY_H_ #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -44,6 +45,9 @@ class MacroRegistry final { absl::optional FindMacro(absl::string_view name, size_t arg_count, bool receiver_style) const; + // Returns a copy of all registered macros. + std::vector ListMacros() const; + private: bool RegisterMacroImpl(const Macro& macro); diff --git a/parser/parser.cc b/parser/parser.cc index d430e3169..d9f74e712 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1707,8 +1707,12 @@ absl::StatusOr ParseImpl(const cel::Source& source, class ParserImpl : public cel::Parser { public: explicit ParserImpl(const ParserOptions& options, - cel::MacroRegistry macro_registry) - : options_(options), macro_registry_(std::move(macro_registry)) {} + cel::MacroRegistry macro_registry, + absl::flat_hash_set library_ids) + : options_(options), + macro_registry_(std::move(macro_registry)), + library_ids_(std::move(library_ids)) {} + absl::StatusOr> Parse( const cel::Source& source) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, @@ -1717,9 +1721,12 @@ class ParserImpl : public cel::Parser { std::move(parse_result.source_info)); } + std::unique_ptr ToBuilder() const override; + private: const ParserOptions options_; const cel::MacroRegistry macro_registry_; + absl::flat_hash_set library_ids_; }; class ParserBuilderImpl : public cel::ParserBuilder { @@ -1796,21 +1803,28 @@ class ParserBuilderImpl : public cel::ParserBuilder { macros_.clear(); } + absl::flat_hash_set library_ids(library_ids_); + // Hack to support adding the standard library macros either by option or // with a library configurer. 
if (!options_.disable_standard_macros && !library_ids_.contains("stdlib")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(Macro::AllMacros())); + library_ids.insert("stdlib"); } if (options_.enable_optional_syntax && !library_ids_.contains("optional")) { CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptMapMacro())); CEL_RETURN_IF_ERROR(macro_registry.RegisterMacro(cel::OptFlatMapMacro())); + library_ids.insert("optional"); } CEL_RETURN_IF_ERROR(macro_registry.RegisterMacros(individual_macros)); - return std::make_unique(options_, std::move(macro_registry)); + return std::make_unique(options_, std::move(macro_registry), + std::move(library_ids)); } private: + friend class ParserImpl; + ParserOptions options_; std::vector macros_; absl::flat_hash_set library_ids_; @@ -1818,6 +1832,13 @@ class ParserBuilderImpl : public cel::ParserBuilder { absl::flat_hash_map library_subsets_; }; +std::unique_ptr ParserImpl::ToBuilder() const { + auto ins = std::make_unique(options_); + ins->library_ids_ = library_ids_; + ins->macros_ = macro_registry_.ListMacros(); + return ins; +} + } // namespace absl::StatusOr Parse(absl::string_view expression, diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 0992385f7..7cc21ff26 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -83,6 +83,9 @@ class Parser { // Parses the given source into a CEL AST. virtual absl::StatusOr> Parse( const cel::Source& source) const = 0; + + // Returns a builder initialized with the configuration of this parser. 
+ virtual std::unique_ptr ToBuilder() const = 0; }; } // namespace cel diff --git a/parser/parser_test.cc b/parser/parser_test.cc index c96845e67..3659fd8fd 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1973,6 +1973,81 @@ TEST(NewParserBuilderTest, ForwardsOptions) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(NewParserBuilderTest, ToBuilderCopiesConfig) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_optional_syntax = true; + builder->GetOptions().disable_standard_macros = true; + ASSERT_THAT(builder->AddLibrary({"custom_lib", + [](cel::ParserBuilder& b) { + return b.AddMacro(cel::HasMacro()); + }}), + IsOk()); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + EXPECT_TRUE(derived_builder->GetOptions().enable_optional_syntax); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a.?b && has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + +TEST(NewParserBuilderTest, ToBuilderHandlesStdlibAndOptionalByLibrary) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = true; + builder->GetOptions().enable_optional_syntax = false; + + // Abusing the library ids for testing. Real uses should use subsetting. + ASSERT_THAT( + builder->AddLibrary( + {"stdlib", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + ASSERT_THAT( + builder->AddLibrary( + {"optional", [](cel::ParserBuilder& b) { return absl::OkStatus(); }}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + // Should be ignored now. 
+ derived_builder->GetOptions().disable_standard_macros = false; + derived_builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b)")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + EXPECT_EQ(w.Print(ast->root_expr()), + "has(\n" + " a^#2:Expr.Ident#.b^#3:Expr.Select#\n" + ")^#1:Expr.Call#"); +} + +TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { + auto builder = cel::NewParserBuilder(); + builder->GetOptions().disable_standard_macros = false; + builder->GetOptions().enable_optional_syntax = true; + + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + auto derived_builder = parser->ToBuilder(); + + ASSERT_OK_AND_ASSIGN(auto derived_parser, + std::move(*derived_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("has(a.b) && [?a]")); + ASSERT_OK_AND_ASSIGN(auto ast, derived_parser->Parse(*source)); + EXPECT_FALSE(ast->IsChecked()); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From f8ed1bdd658435dee2de9dfd7712a57cf5fe07f2 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 24 Apr 2026 15:16:43 -0700 Subject: [PATCH 52/88] Remove invalid 'size' arg, use alignment as size. 
PiperOrigin-RevId: 905243616 --- internal/new.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/new.cc b/internal/new.cc index 05396e624..31ec82a08 100644 --- a/internal/new.cc +++ b/internal/new.cc @@ -114,7 +114,7 @@ void AlignedDelete(void* ptr, std::align_val_t alignment) noexcept { ::operator delete(ptr, alignment); #else if (static_cast(alignment) <= kDefaultNewAlignment) { - SizedDelete(ptr, size); + ::operator delete(ptr); } else { #if defined(_MSC_VER) _aligned_free(ptr); From 23c9913ef7f8ed8835165fc6c65617e3090f1630 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 28 Apr 2026 16:26:10 -0700 Subject: [PATCH 53/88] Add option to specify an arena and persist the resolved types from type checking. PiperOrigin-RevId: 907236091 --- checker/BUILD | 5 ++++ checker/internal/type_checker_impl.cc | 34 +++++++++++++++++++------ checker/internal/type_checker_impl.h | 4 +-- checker/type_checker.cc | 36 +++++++++++++++++++++++++++ checker/type_checker.h | 15 ++++++++--- checker/validation_result.h | 20 ++++++++++++++- compiler/BUILD | 2 ++ compiler/compiler.h | 12 +++++++-- compiler/compiler_factory.cc | 7 +++--- compiler/compiler_factory_test.cc | 23 +++++++++++++++++ 10 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 checker/type_checker.cc diff --git a/checker/BUILD b/checker/BUILD index f1e0cef3c..27a1eb84e 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -50,7 +50,9 @@ cc_library( ":type_check_issue", "//common:ast", "//common:source", + "//common:type", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -74,11 +76,14 @@ cc_test( cc_library( name = "type_checker", + srcs = ["type_checker.cc"], hdrs = ["type_checker.h"], deps = [ ":validation_result", "//common:ast", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", + 
"@com_google_protobuf//:protobuf", ], ) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 8f67efbde..05601fdbb 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -1176,11 +1177,13 @@ class ResolveRewriter : public AstRewriterBase { explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, - Ast::ReferenceMap& references, Ast::TypeMap& types) + Ast::ReferenceMap& references, Ast::TypeMap& types, + ValidationResult::TypeMap& resolved_types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), type_map_(types), + resolved_types_(resolved_types), options_(options) {} bool PostVisitRewrite(Expr& expr) override { bool rewritten = false; @@ -1235,6 +1238,7 @@ class ResolveRewriter : public AstRewriterBase { return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); + resolved_types_[expr.id()] = iter->second; rewritten = true; } @@ -1249,23 +1253,28 @@ class ResolveRewriter : public AstRewriterBase { const TypeInferenceContext& inference_context_; Ast::ReferenceMap& reference_map_; Ast::TypeMap& type_map_; + ValidationResult::TypeMap& resolved_types_; const CheckerOptions& options_; }; } // namespace -absl::StatusOr TypeCheckerImpl::Check( - std::unique_ptr ast) const { - google::protobuf::Arena type_arena; +absl::StatusOr TypeCheckerImpl::CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + std::optional type_arena; + if (arena == nullptr) { + type_arena.emplace(); + arena = &(*type_arena); + } std::vector issues; CEL_ASSIGN_OR_RETURN(auto generator, NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( - &type_arena, options_.enable_legacy_null_assignment); + arena, options_.enable_legacy_null_assignment); 
ResolveVisitor visitor(std::move(generator), env_, *ast, - type_inference_context, issues, &type_arena); + type_inference_context, issues, arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; @@ -1310,9 +1319,10 @@ absl::StatusOr TypeCheckerImpl::Check( // Apply updates as needed. // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. + ValidationResult::TypeMap resolved_types; ResolveRewriter rewriter(visitor, type_inference_context, options_, ast->mutable_reference_map(), - ast->mutable_type_map()); + ast->mutable_type_map(), resolved_types); AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); @@ -1325,7 +1335,15 @@ absl::StatusOr TypeCheckerImpl::Check( {cel::ExtensionSpec::Component::kRuntime})); } - return ValidationResult(std::move(ast), std::move(issues)); + auto result = ValidationResult(std::move(ast), std::move(issues)); + if (!type_arena.has_value()) { + // cel::Type values will expire after this function returns when the local + // arena is destructed. Only set the resolved type map if we're using the + // caller's arena. 
+ result.SetResolvedTypeMap(std::move(resolved_types)); + } + + return result; } std::unique_ptr TypeCheckerImpl::ToBuilder() const { diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 71683276d..9ee9a50d0 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -42,8 +42,8 @@ class TypeCheckerImpl : public TypeChecker { TypeCheckerImpl(TypeCheckerImpl&&) = delete; TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; - absl::StatusOr Check( - std::unique_ptr ast) const override; + absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const override; std::unique_ptr ToBuilder() const override; diff --git a/checker/type_checker.cc b/checker/type_checker.cc new file mode 100644 index 000000000..6d59e144d --- /dev/null +++ b/checker/type_checker.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "checker/type_checker.h" + +namespace cel { +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast) const { + return CheckImpl(std::move(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::move(ast), arena); +} + +absl::StatusOr TypeChecker::Check(const Ast& ast) const { + return CheckImpl(std::make_unique(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + const Ast& ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::make_unique(ast), arena); +} +} // namespace cel diff --git a/checker/type_checker.h b/checker/type_checker.h index e47b7dca6..edb6cc91f 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -16,10 +16,13 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ #include +#include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "google/protobuf/arena.h" namespace cel { @@ -42,13 +45,19 @@ class TypeChecker { // A non-ok status is returned if type checking can't reasonably complete // (e.g. if an internal precondition is violated or an extension returns an // error). - virtual absl::StatusOr Check( - std::unique_ptr ast) const = 0; + absl::StatusOr Check(std::unique_ptr ast) const; + absl::StatusOr Check(std::unique_ptr ast, + google::protobuf::Arena* arena) const; + absl::StatusOr Check(const Ast& ast) const; + absl::StatusOr Check(const Ast& ast, + google::protobuf::Arena* arena) const; // Returns a builder initialized with the configuration of this type checker. virtual std::unique_ptr ToBuilder() const = 0; - // TODO(uncreated-issue/73): add overload for cref AST. 
+ private: + virtual absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* absl_nullable arena) const = 0; }; } // namespace cel diff --git a/checker/validation_result.h b/checker/validation_result.h index 8c84a84da..f424e7f6f 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -15,26 +15,31 @@ #ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ +#include #include #include #include #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" #include "common/source.h" +#include "common/type.h" namespace cel { -// ValidationResult holds the result of TypeChecking. +// ValidationResult holds the result of type checking. // // Error states are captured as type check issues where possible. class ValidationResult { public: + using TypeMap = absl::flat_hash_map; + ValidationResult(std::unique_ptr ast, std::vector issues) : ast_(std::move(ast)), issues_(std::move(issues)) {} @@ -71,6 +76,18 @@ class ValidationResult { return std::move(source_); } + // Returns the resolved type map for the AST. + // + // Only populated if the AST was checked with an explicit arena. + // + // The type entries may have storage in the arena or reference type + // information from the type checker that produced the AST. This means the map + // is only valid as long as both the type checker and the arena are valid. + const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; } + void SetResolvedTypeMap(TypeMap resolved_type_map) { + resolved_type_map_ = std::move(resolved_type_map); + } + // Returns a string representation of the issues in the result suitable for // display. 
// @@ -89,6 +106,7 @@ class ValidationResult { private: absl_nullable std::unique_ptr ast_; + TypeMap resolved_type_map_; std::vector issues_; absl_nullable std::unique_ptr source_; }; diff --git a/compiler/BUILD b/compiler/BUILD index 170f1068b..50bc1c9fa 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -28,9 +28,11 @@ cc_library( "//parser:options", "//parser:parser_interface", "//validator", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) diff --git a/compiler/compiler.h b/compiler/compiler.h index 48fa4e0b1..6d07e72c2 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -19,6 +19,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -29,6 +30,7 @@ #include "parser/options.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" namespace cel { @@ -126,10 +128,16 @@ class Compiler { virtual ~Compiler() = default; virtual absl::StatusOr Compile( - absl::string_view source, absl::string_view description) const = 0; + absl::string_view source, absl::string_view description, + google::protobuf::Arena* absl_nullable arena) const = 0; absl::StatusOr Compile(absl::string_view source) const { - return Compile(source, ""); + return Compile(source, "", nullptr); + } + + absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const { + return Compile(source, description, nullptr); } // Accessor for the underlying type checker. 
diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 3e9871706..14586825e 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -33,6 +33,7 @@ #include "parser/parser.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -50,13 +51,13 @@ class CompilerImpl : public Compiler { validator_(std::move(validator)) {} absl::StatusOr Compile( - absl::string_view expression, - absl::string_view description) const override { + absl::string_view expression, absl::string_view description, + google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast))); + type_checker_->Check(std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index d217e4cc7..214c23765 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -37,6 +37,7 @@ #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" #include "validator/timestamp_literal_validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -390,5 +391,27 @@ TEST(CompilerFactoryTest, ToBuilderWorks) { EXPECT_TRUE(result.IsValid()); } +TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + google::protobuf::Arena arena; 
+ ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[[1, 2, 3]][?0]", "", &arena)); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + auto it = result.GetResolvedTypeMap().find(ast->root_expr().id()); + ASSERT_TRUE(it != result.GetResolvedTypeMap().end()); + EXPECT_TRUE( + it->second.IsOptional() && + it->second.GetOptional().GetParameter().IsList() && + it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); +} + } // namespace } // namespace cel From 9e73d93f77a159a149edd9a465d092cff03d702a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 30 Apr 2026 13:04:30 -0700 Subject: [PATCH 54/88] Deprecate option to switch to old accumulator var name in parser. cel::ParserOption::enable_hidden_accumulator_var now has no effect and will be removed in a later change. The standard / extension macros should always use `@result` now. PiperOrigin-RevId: 908333116 --- common/expr.h | 4 +- common/expr_factory.h | 1 - extensions/comprehensions_v2_macros.cc | 72 +++---- parser/macro.cc | 44 ++-- parser/macro_expr_factory.h | 3 +- parser/macro_expr_factory_test.cc | 2 +- parser/options.h | 4 +- parser/parser.cc | 19 +- parser/parser_test.cc | 270 ------------------------- 9 files changed, 72 insertions(+), 347 deletions(-) diff --git a/common/expr.h b/common/expr.h index 9c6f508c6..7305c2c9f 100644 --- a/common/expr.h +++ b/common/expr.h @@ -45,7 +45,9 @@ class MapExprEntry; class MapExpr; class ComprehensionExpr; -inline constexpr absl::string_view kAccumulatorVariableName = "__result__"; +inline constexpr absl::string_view kAccumulatorVariableName = "@result"; +inline constexpr absl::string_view kDeprecatedAccumulatorVariableName = + "__result__"; bool operator==(const Expr& lhs, const Expr& rhs); diff --git a/common/expr_factory.h b/common/expr_factory.h index c8a9b831f..b9769b457 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -357,7 +357,6 @@ class ExprFactory { friend class ParserMacroExprFactory; ExprFactory() 
: accu_var_(kAccumulatorVariableName) {} - explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {} std::string accu_var_; }; diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index a8de3a103..134fb80ff 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -56,15 +56,15 @@ absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, args[0], "all() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("all() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("all() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -102,15 +102,15 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, args[0], "exists() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("exists() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto 
init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -153,15 +153,15 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, "existsOne() second variable must be different " "from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("existsOne() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("existsOne() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -205,15 +205,15 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -254,15 +254,15 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the 
first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -305,15 +305,15 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -353,15 +353,15 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( 
args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -403,17 +403,17 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -453,17 +453,17 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if 
(args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); diff --git a/parser/macro.cc b/parser/macro.cc index eaa1ebd1a..8f8c9e596 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -91,10 +91,10 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("all() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -123,10 +123,10 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -157,10 +157,10 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if 
(args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists_one() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -196,10 +196,10 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -229,10 +229,10 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -264,10 +264,10 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], 
absl::StrCat("filter() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto name = args[0].ident_expr().name(); @@ -302,10 +302,10 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); @@ -341,10 +341,10 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optFlatMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index ffba5e2f2..c66aa4fe0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -319,8 +319,7 @@ class MacroExprFactory : protected ExprFactory { friend class ParserMacroExprFactory; friend class TestMacroExprFactory; - explicit MacroExprFactory(absl::string_view accu_var) - : ExprFactory(accu_var) {} + explicit MacroExprFactory() = default; }; } // namespace cel diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 04705eec6..489538be1 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -27,7 +27,7 @@ namespace cel { class TestMacroExprFactory final : public MacroExprFactory { public: 
- TestMacroExprFactory() : MacroExprFactory(kAccumulatorVariableName) {} + TestMacroExprFactory() = default; ExprId id() const { return id_; } diff --git a/parser/options.h b/parser/options.h index a41d16104..916a941f0 100644 --- a/parser/options.h +++ b/parser/options.h @@ -51,7 +51,9 @@ struct ParserOptions final { // Disable standard macros (has, all, exists, exists_one, filter, map). bool disable_standard_macros = false; - // Enable hidden accumulator variable '@result' for builtin comprehensions. + // Deprecated: The builtin and extension macros now always use the new + // accumulator variable name. + // This option has no effect. bool enable_hidden_accumulator_var = true; // Enables support for identifier quoting syntax: diff --git a/parser/parser.cc b/parser/parser.cc index d9f74e712..f4ee3a1c5 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -163,9 +163,8 @@ SourceRange SourceRangeFromParserRuleContext( class ParserMacroExprFactory final : public MacroExprFactory { public: - explicit ParserMacroExprFactory(const cel::Source& source, - absl::string_view accu_var) - : MacroExprFactory(accu_var), source_(source) {} + explicit ParserMacroExprFactory(const cel::Source& source) + : source_(source) {} void BeginMacro(SourceRange macro_position) { macro_position_ = macro_position; @@ -607,13 +606,12 @@ class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: ParserVisitor(const cel::Source& source, int max_recursion_depth, - absl::string_view accu_var, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, bool enable_quoted_identifiers = false) : source_(source), - factory_(source_, accu_var), + factory_(source_), macro_registry_(macro_registry), recursion_depth_(0), max_recursion_depth_(max_recursion_depth), @@ -1654,14 +1652,9 @@ absl::StatusOr ParseImpl(const cel::Source& source, CommonTokenStream tokens(&lexer); CelParser parser(&tokens); ExprRecursionListener 
listener(options.max_recursion_depth); - absl::string_view accu_var = cel::kAccumulatorVariableName; - if (options.enable_hidden_accumulator_var) { - accu_var = cel::kHiddenAccumulatorVariableName; - } - ParserVisitor visitor(source, options.max_recursion_depth, accu_var, - registry, options.add_macro_calls, - options.enable_optional_syntax, - options.enable_quoted_identifiers); + ParserVisitor visitor( + source, options.max_recursion_depth, registry, options.add_macro_calls, + options.enable_optional_syntax, options.enable_quoted_identifiers); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 3659fd8fd..a1a65481d 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1473,7 +1473,6 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); ParserOptions options; - options.enable_hidden_accumulator_var = true; if (!test_info.M.empty()) { options.add_macro_calls = true; } @@ -1628,271 +1627,6 @@ TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { EXPECT_THAT(result, IsOk()); } -const std::vector& UpdatedAccuVarTestCases() { - static const std::vector* kInstance = new std::vector{ - {"[].exists(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " false^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " !_(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#\n" - " )^#10:Expr.Call#,\n" - " // LoopStep\n" - " _||_(\n" - " __result__^#11:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#12:Expr.Call#,\n" - " // Result\n" - " __result__^#13:Expr.Ident#)^#14:Expr.Comprehension#"}, - {"[].exists_one(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // 
Accumulator\n" - " __result__,\n" - " // Init\n" - " 0^#7:int64#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " 1^#10:int64#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " _==_(\n" - " __result__^#14:Expr.Ident#,\n" - " 1^#15:int64#\n" - " )^#16:Expr.Call#)^#17:Expr.Comprehension#"}, - {"[].all(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " true^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#,\n" - " // LoopStep\n" - " _&&_(\n" - " __result__^#10:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#4:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#5:Expr.Call#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x > 0, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#10:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#11:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " 
_+_(\n" - " __result__^#12:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#7:Expr.Ident#,\n" - " 1^#9:int64#\n" - " )^#8:Expr.Call#\n" - " ]^#13:Expr.CreateList#\n" - " )^#14:Expr.Call#,\n" - " __result__^#15:Expr.Ident#\n" - " )^#16:Expr.Call#,\n" - " // Result\n" - " __result__^#17:Expr.Ident#)^#18:Expr.Comprehension#"}, - {"[].filter(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " x^#3:Expr.Ident#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " __result__^#14:Expr.Ident#)^#15:Expr.Comprehension#"}, - // Maintain restriction on '__result__' variable name until the default is - // changed everywhere. 
- { - "[].map(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true)\n" - " | ...................^", - }, - { - "[].map(__result__, true, false)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true, false)\n" - " | ...................^", - }, - { - "[].filter(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: filter() variable name cannot be __result__\n" - " | [].filter(__result__, true)\n" - " | ......................^", - }, - { - "[].exists(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: exists() variable name cannot be __result__\n" - " | [].exists(__result__, true)\n" - " | ......................^", - }, - { - "[].all(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: all() variable name cannot be __result__\n" - " | [].all(__result__, true)\n" - " | ...................^", - }, - { - "[].exists_one(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:27: exists_one() variable name cannot be " - "__result__\n" - " | [].exists_one(__result__, true)\n" - " | ..........................^", - }}; - return *kInstance; -} - -class UpdatedAccuVarDisabledTest : public testing::TestWithParam {}; - -TEST_P(UpdatedAccuVarDisabledTest, Parse) { - const TestInfo& test_info = GetParam(); - ParserOptions options; - options.enable_hidden_accumulator_var = false; - if (!test_info.M.empty()) { - options.add_macro_calls = true; - } - - auto result = - EnrichedParse(test_info.I, Macro::AllMacros(), "", options); - if (test_info.E.empty()) { - EXPECT_THAT(result, IsOk()); - } else { - EXPECT_THAT(result, Not(IsOk())); - EXPECT_EQ(test_info.E, result.status().message()); - } - - if (!test_info.P.empty()) { - KindAndIdAdorner kind_and_id_adorner; - ExprPrinter w(kind_and_id_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) - << 
result->parsed_expr().ShortDebugString(); - } - - if (!test_info.L.empty()) { - LocationAdorner location_adorner(result->parsed_expr().source_info()); - ExprPrinter w(location_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) - << result->parsed_expr().ShortDebugString(); - } - - if (!test_info.R.empty()) { - EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( - result->enriched_source_info())); - } - - if (!test_info.M.empty()) { - EXPECT_EQ(test_info.M, ConvertMacroCallsToString( - result.value().parsed_expr().source_info())) - << result->parsed_expr().ShortDebugString(); - } -} - TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); @@ -2058,9 +1792,5 @@ std::string TestName(const testing::TestParamInfo& test_info) { INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases), TestName); -INSTANTIATE_TEST_SUITE_P(UpdatedAccuVarTest, UpdatedAccuVarDisabledTest, - testing::ValuesIn(UpdatedAccuVarTestCases()), - TestName); - } // namespace } // namespace google::api::expr::parser From 1cf21eec91baa4181b481ff4bf6e25b9b5e9afe9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 10:35:24 -0700 Subject: [PATCH 55/88] Optionally track structured error messages in parser. Add a parse overload that reports issues with location to an output parameter where reasonable. This is used make the error handling more consistent when using bundled parse + typecheck. 
PiperOrigin-RevId: 910112013 --- compiler/BUILD | 2 + compiler/compiler.h | 2 + compiler/compiler_factory.cc | 49 +++++++++++++++----- compiler/compiler_factory_test.cc | 14 ++++++ extensions/math_ext_test.cc | 38 +++++----------- parser/BUILD | 2 + parser/parser.cc | 74 +++++++++++++++++++++---------- parser/parser_interface.h | 50 ++++++++++++++++++++- parser/parser_test.cc | 19 ++++++++ 9 files changed, 186 insertions(+), 64 deletions(-) diff --git a/compiler/BUILD b/compiler/BUILD index 50bc1c9fa..d4a0ab4ac 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -42,10 +42,12 @@ cc_library( hdrs = ["compiler_factory.h"], deps = [ ":compiler", + "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:type_checker_builder_factory", "//checker:validation_result", + "//common:ast", "//common:source", "//internal:noop_delete", "//internal:status_macros", diff --git a/compiler/compiler.h b/compiler/compiler.h index 6d07e72c2..27237df60 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -97,6 +97,8 @@ struct CompilerLibrarySubset { struct CompilerOptions { ParserOptions parser_options; CheckerOptions checker_options; + // If true, parse errors will be adapted to issues where possible. + bool adapt_parser_errors = false; }; // Interface for CEL CompilerBuilder objects. 
diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 14586825e..ed22c5630 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -17,16 +17,19 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/source.h" #include "compiler/compiler.h" #include "internal/status_macros.h" @@ -45,19 +48,38 @@ class CompilerImpl : public Compiler { CompilerImpl(std::unique_ptr type_checker, std::unique_ptr parser, // Copy the validator in case builder is reused. - Validator validator) + Validator validator, CompilerOptions options) : type_checker_(std::move(type_checker)), parser_(std::move(parser)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::StatusOr Compile( absl::string_view expression, absl::string_view description, google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); - CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (!options_.adapt_parser_errors || + ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + 
result.SetSource(std::move(source)); + return result; + } CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast), arena)); + type_checker_->Check(*std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { @@ -76,16 +98,18 @@ class CompilerImpl : public Compiler { std::unique_ptr type_checker_; std::unique_ptr parser_; Validator validator_; + CompilerOptions options_; }; class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, std::unique_ptr parser_builder, - Validator validator = Validator()) + Validator validator, CompilerOptions options) : type_checker_builder_(std::move(type_checker_builder)), parser_builder_(std::move(parser_builder)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -146,23 +170,23 @@ class CompilerBuilderImpl : public CompilerBuilder { absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); - return std::make_unique(std::move(type_checker), - std::move(parser), validator_); + return std::make_unique( + std::move(type_checker), std::move(parser), validator_, options_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; Validator validator_; + CompilerOptions options_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; }; std::unique_ptr CompilerImpl::ToBuilder() const { - auto builder = std::make_unique( - type_checker_->ToBuilder(), parser_->ToBuilder(), validator_); - return builder; + return std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_, options_); } } // namespace @@ -179,7 +203,8 @@ absl::StatusOr> NewCompilerBuilder( auto parser_builder = 
NewParserBuilder(options.parser_options); return std::make_unique(std::move(type_checker_builder), - std::move(parser_builder)); + std::move(parser_builder), + Validator(), options); } } // namespace cel diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 214c23765..035fd8aa6 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -413,5 +413,19 @@ TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); } +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + CompilerOptions opts; + opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + } // namespace } // namespace cel diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 3088e6fa8..72605648f 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -23,7 +23,6 @@ #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -110,19 +109,6 @@ struct MacroTestCase { absl::string_view err = ""; }; -std::string FormatIssues(const cel::ValidationResult& result) { - std::string issues; - for (const auto& issue : result.GetIssues()) { - if (!issues.empty()) { - absl::StrAppend(&issues, "\n", - issue.ToDisplayString(*result.GetSource())); - } else { - issues = issue.ToDisplayString(*result.GetSource()); - } - } - return issues; -} - class TestFunction : public CelFunction { public: explicit 
TestFunction(absl::string_view name) @@ -352,10 +338,11 @@ TEST_P(MathExtMacroParamsTest, ParserTests) { TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { const MacroTestCase& test_case = GetParam(); - - ASSERT_OK_AND_ASSIGN( - auto compiler_builder, - cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CompilerOptions compile_opts; + compile_opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + cel::NewCompilerBuilder( + internal::GetTestingDescriptorPool(), compile_opts)); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); @@ -381,16 +368,16 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); - auto result = compiler->Compile(test_case.expr, ""); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); if (!test_case.err.empty()) { - EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr(test_case.err))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); return; } - ASSERT_THAT(result, IsOk()); - ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( @@ -411,9 +398,8 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); - - ASSERT_OK_AND_ASSIGN(auto program, - runtime->CreateProgram(*result->ReleaseAst())); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); google::protobuf::Arena arena; cel::Activation activation; diff --git a/parser/BUILD b/parser/BUILD index 63813bb59..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -244,9 +244,11 @@ cc_library( ":options", 
"//common:ast", "//common:source", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/parser/parser.cc b/parser/parser.cc index f4ee3a1c5..709e2fd41 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -112,13 +112,12 @@ struct ParserError { }; std::string DisplayParserError(const cel::Source& source, - const ParserError& error) { - auto location = - source.GetLocation(error.range.begin).value_or(SourceLocation{}); + SourceLocation location, + absl::string_view message) { return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", source.description(), location.line, // add one to the 0-based column - location.column + 1, error.message), + location.column + 1, message), source.DisplayErrorLocation(location)); } @@ -209,7 +208,7 @@ class ParserMacroExprFactory final : public MacroExprFactory { bool HasErrors() const { return error_count_ != 0; } - std::string ErrorMessage() { + std::vector CollectIssues() { // Errors are collected as they are encountered, not by their location // within the source. To have a more stable error message as implementation // details change, we sort the collected errors by their source location @@ -226,20 +225,23 @@ class ParserMacroExprFactory final : public MacroExprFactory { }); // Build the summary error message using the sorted errors. bool errors_truncated = error_count_ > 100; - std::vector messages; - messages.reserve( + std::vector issues; + issues.reserve( errors_.size() + errors_truncated); // Reserve space for the transform and an // additional element when truncation occurs. 
- std::transform(errors_.begin(), errors_.end(), std::back_inserter(messages), - [this](const ParserError& error) { - return cel::DisplayParserError(source_, error); - }); + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); if (errors_truncated) { - messages.emplace_back( - absl::StrCat(error_count_ - 100, " more errors were truncated.")); + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); } - return absl::StrJoin(messages, "\n"); + return issues; } void AddMacroCall(int64_t macro_id, absl::string_view function, @@ -602,6 +604,15 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: @@ -673,7 +684,7 @@ class ParserVisitor final : public CelBaseVisitor, const std::string& msg, std::exception_ptr e) override; bool HasErrored() const; - std::string ErrorMessage(); + std::vector CollectIssues(); private: template @@ -1434,7 +1445,9 @@ void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } -std::string ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, @@ -1638,9 +1651,10 @@ struct ParseResult { 
EnrichedSourceInfo enriched_source_info; }; -absl::StatusOr ParseImpl(const cel::Source& source, - const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1673,13 +1687,23 @@ absl::StatusOr ParseImpl(const cel::Source& source, expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return { @@ -1706,10 +1730,12 @@ class ParserImpl : public cel::Parser { macro_registry_(std::move(macro_registry)), library_ids_(std::move(library_ids)) {} - absl::StatusOr> Parse( - const cel::Source& source) const override { + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, - ParseImpl(source, macro_registry_, options_)); + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 
7cc21ff26..ad6e8ca84 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -16,10 +16,14 @@ #include #include +#include +#include +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" @@ -73,6 +77,26 @@ class ParserBuilder { virtual absl::StatusOr> Build() = 0; }; +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. @@ -81,13 +105,35 @@ class Parser { virtual ~Parser() = default; // Parses the given source into a CEL AST. - virtual absl::StatusOr> Parse( - const cel::Source& source) const = 0; + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; // Returns a builder initialized with the configuration of this parser. 
virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; }; +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a1a65481d..587b63a30 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,25 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. 
+ EXPECT_EQ(issues[0].location().column, 3); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From 2f06d90f5b593269c2b1f58de3bfd5c8fc2fa895 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 12:32:22 -0700 Subject: [PATCH 56/88] Add checker support for block. This is needed for re-checking expressions that were produced as a part of policy compilation. PiperOrigin-RevId: 910179322 --- checker/internal/BUILD | 6 + checker/internal/type_checker_impl.cc | 77 ++++++++- checker/internal/type_checker_impl_test.cc | 95 +++++++++++ conformance/BUILD | 5 +- conformance/service.cc | 115 +------------- extensions/BUILD | 5 +- extensions/bindings_ext.cc | 32 +++- extensions/bindings_ext.h | 6 +- testutil/BUILD | 20 +++ testutil/test_macros.cc | 175 +++++++++++++++++++++ testutil/test_macros.h | 33 ++++ 11 files changed, 446 insertions(+), 123 deletions(-) create mode 100644 testutil/test_macros.cc create mode 100644 testutil/test_macros.h diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1c560cdb9..f4c60f937 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -179,6 +180,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", + "//common:ast_proto", "//common:container", "//common:decl", "//common:expr", @@ -187,13 +189,17 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", "//testutil:baseline_tests", + 
"//testutil:test_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 05601fdbb..2472d7def 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -25,6 +25,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -59,6 +60,15 @@ namespace cel::checker_internal { namespace { +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; @@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase { arena_(arena), current_scope_(&root_scope_) {} - void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } void PostVisitExpr(const Expr& expr) override { if (expr_stack_.empty()) { return; } expr_stack_.pop_back(); + if 
(expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view field_name); void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); // Get the assigned type of the given subexpression. Should only be called if // the given subexpression is expected to have already been checked. @@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; // Select operations that need to be resolved outside of the traversal. // These are handled separately to disambiguate between namespaces and field // accesses @@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { - // Follows list type inferencing behavior in Go (see map comments above). + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + // Follows list type inferencing behavior in Go (see map comments above). 
Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); auto assignability_context = inference_context_->CreateAssignabilityContext(); @@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { } } +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // TODO(uncreated-issue/90): This is slightly inconsistent with the java + // runtime implementation which just requires the references to be acyclic. 
+ auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, @@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase { if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { - auto flattened_type = - FlattenType(inference_context_.FinalizeType(iter->second)); + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); - resolved_types_[expr.id()] = iter->second; + resolved_types_[expr.id()] = finalized_type; rewritten = true; } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index e6cd641d6..893f0689d 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -26,6 +26,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -36,6 +37,7 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/ast_proto.h" #include "common/container.h" #include "common/decl.h" #include "common/expr.h" @@ -45,7 +47,10 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" #include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include 
"cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -108,6 +113,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() { return &(*kArena); } +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", @@ -272,6 +288,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + env.InsertFunctionIfAbsent(std::move(not_op)); env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); @@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); return absl::OkStatus(); } @@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + 
MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. 
+ ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/conformance/BUILD b/conformance/BUILD index 139739891..9b527cf35 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -69,6 +69,7 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -221,7 +222,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # cel.@block + # no optional support for legacy types "block_ext/basic/optional_list", "block_ext/basic/optional_map", "block_ext/basic/optional_map_chained", @@ -231,7 +232,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ _TESTS_TO_SKIP_CHECKED = [ # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. 
- "block_ext", + # "block_ext", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ diff --git a/conformance/service.cc b/conformance/service.cc index 3edc214e6..463334bb5 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,7 +14,6 @@ #include "conformance/service.h" -#include #include #include #include @@ -36,11 +35,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" @@ -48,7 +44,6 @@ #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl_proto_v1alpha1.h" -#include "common/expr.h" #include "common/internal/value_conversion.h" #include "common/source.h" #include "common/value.h" @@ -72,8 +67,6 @@ #include "extensions/select_optimization.h" #include "extensions/strings.h" #include "internal/status_macros.h" -#include "parser/macro.h" -#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -85,6 +78,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -106,109 +100,6 @@ namespace google::api::expr::runtime { namespace { -bool IsCelNamespace(const cel::Expr& target) { - return target.has_ident_expr() && target.ident_expr().name() == "cel"; -} - -absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& bindings_arg = 
args[0]; - if (!bindings_arg.has_list_expr()) { - return factory.ReportErrorAt( - bindings_arg, "cel.block requires the first arg to be a list literal"); - } - return factory.NewCall("cel.@block", args); -} - -absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& index_arg = args[0]; - if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - int64_t index = index_arg.const_expr().int_value(); - if (index < 0) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - return factory.NewIdent(absl::StrCat("@index", index)); -} - -absl::optional CelIterVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.iterVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.iterVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::optional CelAccuVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - 
depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.accuVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.accuVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { - CEL_ASSIGN_OR_RETURN(auto block_macro, - cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); - CEL_ASSIGN_OR_RETURN(auto index_macro, - cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); - CEL_ASSIGN_OR_RETURN( - auto iter_var_macro, - cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); - CEL_ASSIGN_OR_RETURN( - auto accu_var_macro, - cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); - return absl::OkStatus(); -} - google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -250,7 +141,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); - CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, 
@@ -285,6 +176,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( diff --git a/extensions/BUILD b/extensions/BUILD index ff37e2c3f..05104a4a5 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -215,7 +215,10 @@ cc_library( srcs = ["bindings_ext.cc"], hdrs = ["bindings_ext.h"], deps = [ - "//common:ast", + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", "//compiler", "//internal:status_macros", "//parser:macro", diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index f097709ca..c59f724bd 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -21,7 +21,10 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "common/ast.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -34,6 +37,8 @@ namespace { static constexpr char kCelNamespace[] = "cel"; static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; static constexpr char kUnusedIterVar[] = "#unused"; bool IsTargetNamespace(const Expr& target) { @@ -47,6 +52,19 @@ inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { return absl::OkStatus(); } +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + 
CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + } // namespace std::vector bindings_macros() { @@ -70,8 +88,16 @@ std::vector bindings_macros() { return {*cel_bind}; } -CompilerLibrary BindingsCompilerLibrary() { - return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser); +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; } } // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h index a338b24f6..40b83a37f 100644 --- a/extensions/bindings_ext.h +++ b/extensions/bindings_ext.h @@ -25,6 +25,7 @@ namespace cel::extensions { +constexpr int kBindingsVersionLatest = 1; // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); @@ -35,7 +36,10 @@ inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, } // Declarations for the bindings extension library. -CompilerLibrary BindingsCompilerLibrary(); +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. 
+CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); } // namespace cel::extensions diff --git a/testutil/BUILD b/testutil/BUILD index 292696033..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -62,6 +62,26 @@ cc_library( deps = ["//internal:proto_matchers"], ) +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + deps = [ + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "baseline_tests", testonly = True, diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..158135762 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +absl::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +absl::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar 
requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + 
ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ From ddcece1479ed78ccde1594e47d94eeb841de115f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 6 May 2026 13:56:21 -0700 Subject: [PATCH 57/88] Refactor optional dispatch tables. PiperOrigin-RevId: 911533283 --- common/values/optional_value.cc | 255 ++++++++++++++++---------------- 1 file changed, 124 insertions(+), 131 deletions(-) diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index ad0a65efb..688cf8fb0 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -122,200 +122,185 @@ absl::Status OptionalValueEqual( return absl::OkStatus(); } +google::protobuf::Arena* absl_nullable OptionalValueGetArenaNull( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return nullptr; +} + +OpaqueValue OptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + return common_internal::MakeOptionalValue(dispatcher, content); +} + +bool OptionalValueHasNoValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content) { + return false; +} + +void EmptyOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = + ErrorValue(absl::FailedPreconditionError("optional.none() dereference")); +} + +void NullOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* 
absl_nonnull result) { + *result = NullValue(); +} + +void BoolOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = BoolValue(content.To()); +} + +void IntOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = IntValue(content.To()); +} + +void UintOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UintValue(content.To()); +} + +void DoubleOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = DoubleValue(content.To()); +} + +void DurationOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeDurationValue(content.To()); +} + +void TimestampOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeTimestampValue(content.To()); +} + ABSL_CONST_INIT const OptionalValueDispatcher empty_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, - }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content) -> 
bool { return false; }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = ErrorValue( - absl::FailedPreconditionError("optional.none() dereference")); + .clone = &OptionalValueClone, }, + &OptionalValueHasNoValue, + &EmptyOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent, - cel::Value* absl_nonnull result) -> void { *result = NullValue(); }, + &NullOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); 
- }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = BoolValue(content.To()); - }, + &BoolOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = IntValue(content.To()); - }, + &IntOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, 
&OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UintValue(content.To()); - }, + &UintOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher double_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = DoubleValue(content.To()); - }, + &DoubleOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher duration_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const 
OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeDurationValue(content.To()); - }, + &DurationOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher timestamp_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeTimestampValue(content.To()); - }, + &TimestampOptionalValueValue, }; struct OptionalValueContent { @@ -323,43 +308,51 @@ struct OptionalValueContent { google::protobuf::Arena* absl_nonnull arena; }; +google::protobuf::Arena* absl_nullable GenericOptionalValueGetArena( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) { + return content.To().arena; +} + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + +void GenericOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = *content.To().value; +} + ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = - [](const 
OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent content) -> google::protobuf::Arena* absl_nullable { - return content.To().arena; - }, + .get_arena = &GenericOptionalValueGetArena, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - ABSL_DCHECK(arena != nullptr); - - cel::Value* absl_nonnull result = ::new ( - arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) - cel::Value( - content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { - arena->OwnDestructor(result); - } - return common_internal::MakeOptionalValue( - &optional_value_dispatcher, - OpaqueValueContent::From( - OptionalValueContent{.value = result, .arena = arena})); - }, + .clone = &GenericOptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = *content.To().value; - }, + &GenericOptionalValueValue, }; +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + cel::Value* absl_nonnull result = + ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(content.To().value->Clone(arena)); + if (!ArenaTraits<>::trivially_destructible(result)) { + arena->OwnDestructor(result); + } + return common_internal::MakeOptionalValue( + &optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); +} + } // namespace OptionalValue OptionalValue::Of(cel::Value value, From 5806d30ba86ca40d8ab111e59fa78983afe5319c Mon Sep 17 00:00:00 2001 
From: Justin King Date: Fri, 8 May 2026 13:03:46 -0700 Subject: [PATCH 58/88] Update conformance test skip list PiperOrigin-RevId: 912656156 --- conformance/BUILD | 20 ++++++++++++++++++++ conformance/run.bzl | 6 +++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 9b527cf35..726a11b0b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -201,6 +201,26 @@ _TESTS_TO_SKIP = [ # precision to preserve value. Not available on older compilers where we just use absl::Format. # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", + + # Recent changes + "proto2/set_null/repeated_field_timestamp_null_pruned", + "proto2/set_null/repeated_field_duration_null_pruned", + "proto2/set_null/repeated_field_wrapper_null_pruned", + "proto2/set_null/map_timestamp_null_pruned", + "proto2/set_null/map_duration_null_pruned", + "proto2/set_null/map_wrapper_null_pruned", + "proto3/set_null/repeated_field_timestamp_null_pruned", + "proto3/set_null/repeated_field_duration_null_pruned", + "proto3/set_null/repeated_field_wrapper_null_pruned", + "proto3/set_null/map_timestamp_null_pruned", + "proto3/set_null/map_duration_null_pruned", + "proto3/set_null/map_wrapper_null_pruned", + "string_ext/format/default precision for fixed-point clause with int", + "string_ext/format/default precision for fixed-point clause with uint", + "string_ext/format/default precision for scientific notation with int", + "string_ext/format/default precision for scientific notation with uint", + "namespace/namespace_shadowing/basic", + "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] _TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP diff --git a/conformance/run.bzl b/conformance/run.bzl index 4fcf325c6..d53fd539c 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -70,7 +70,7 @@ def _conformance_test_args(modern, optimize, recursive, 
select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests={}".format(",".join(_expand_tests_to_skip(skip_tests)))) + args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -80,8 +80,8 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ name = _conformance_test_name(name, optimize, recursive), args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( { - "@platforms//os:windows": ["--skip_tests={}".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests={}".format(",".join(skip_tests))], + "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], + "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], }, ), data = data, From cb9dc8a2e71e503655b1992bdba3debc7fda12a7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 8 May 2026 15:45:31 -0700 Subject: [PATCH 59/88] Fix command line argument splitting issue for conformance tests. 
PiperOrigin-RevId: 912731724 --- conformance/run.bzl | 10 +++++----- conformance/run.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/conformance/run.bzl b/conformance/run.bzl index d53fd539c..15850b0aa 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): args = [] if modern: args.append("--modern") @@ -70,7 +70,6 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -78,10 +77,11 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(location " + test + ")" for test in data], + env = select( { - "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], + "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, + "//conditions:default": {"CEL_SKIP_TESTS": ",".join(skip_tests)}, }, ), data = data, diff --git a/conformance/run.cc b/conformance/run.cc index d5a919d76..80164d9a4 100644 --- 
a/conformance/run.cc +++ b/conformance/run.cc @@ -42,6 +42,7 @@ #include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/span.h" @@ -273,6 +274,13 @@ int main(int argc, char** argv) { { auto service = NewConformanceServiceFromFlags(); auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests); + if (const char* env_skip = std::getenv("CEL_SKIP_TESTS"); + env_skip != nullptr) { + for (absl::string_view test : + absl::StrSplit(env_skip, ',', absl::SkipEmpty())) { + tests_to_skip.push_back(std::string(test)); + } + } for (int argi = 1; argi < argc; argi++) { ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, absl::string_view(argv[argi]))); From cf31ddf620b9d809014418e82428863b54190cbb Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 11 May 2026 10:50:27 -0700 Subject: [PATCH 60/88] Introduce `Bind` expression factory helper PiperOrigin-RevId: 913778503 --- common/expr_factory.h | 23 ++++++++++++++ parser/macro_expr_factory_test.cc | 51 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index b9769b457..773217ad9 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -352,6 +352,29 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + 
comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 489538be1..b95cbe16f 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -15,6 +15,7 @@ #include "parser/macro_expr_factory.h" #include +#include #include #include "absl/strings/string_view.h" @@ -39,6 +40,7 @@ class TestMacroExprFactory final : public MacroExprFactory { return NewUnspecified(NextId()); } + using MacroExprFactory::NewBind; using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; @@ -69,6 +71,8 @@ class TestMacroExprFactory final : public MacroExprFactory { namespace { +using ::testing::IsEmpty; + TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); @@ -147,5 +151,52 @@ TEST(MacroExprFactory, CopyComprehension) { factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), 
"a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel From 2e6e9ff4493bfbe0baf883107f3fb7ce6f675d88 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 11 May 2026 21:06:47 -0700 Subject: [PATCH 61/88] Add support for abbreviations and aliases in container configuration for CEL C++ environment YAML. This allows specifying name, abbreviations, and aliases in a container config instead of just a string. 
The string syntax is preserved as an alternative PiperOrigin-RevId: 914038623 --- env/BUILD | 1 + env/config.h | 11 +++- env/env.cc | 12 +++- env/env_test.cc | 30 ++++++++++ env/env_yaml.cc | 107 +++++++++++++++++++++++++++++++-- env/env_yaml_test.cc | 139 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 289 insertions(+), 11 deletions(-) diff --git a/env/BUILD b/env/BUILD index 55297b190..41ffc1723 100644 --- a/env/BUILD +++ b/env/BUILD @@ -52,6 +52,7 @@ cc_library( ":config", "//checker:type_checker_builder", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//compiler", diff --git a/env/config.h b/env/config.h index 10b23d030..e427832ff 100644 --- a/env/config.h +++ b/env/config.h @@ -34,9 +34,16 @@ class Config { struct ContainerConfig { std::string name; - // TODO(uncreated-issue/87): add support for aliases and abbreviations. + std::vector abbreviations; + struct Alias { + std::string alias; + std::string qualified_name; + }; + std::vector aliases; - bool IsEmpty() const { return name.empty(); } + bool IsEmpty() const { + return name.empty() && abbreviations.empty() && aliases.empty(); + } }; void SetContainerConfig(ContainerConfig container_config) { diff --git a/env/env.cc b/env/env.cc index 5a4198497..42652ce59 100644 --- a/env/env.cc +++ b/env/env.cc @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "checker/type_checker_builder.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "compiler/compiler.h" @@ -130,7 +131,16 @@ absl::StatusOr> Env::NewCompilerBuilder() { cel::TypeCheckerBuilder& checker_builder = compiler_builder->GetCheckerBuilder(); - checker_builder.set_container(config_.GetContainerConfig().name); + ExpressionContainer container; + CEL_RETURN_IF_ERROR( + container.SetContainer(config_.GetContainerConfig().name)); + for (const auto& abbr : config_.GetContainerConfig().abbreviations) { + 
CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); + } + for (const auto& alias : config_.GetContainerConfig().aliases) { + CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); + } + checker_builder.SetExpressionContainer(std::move(container)); if (!config_.GetStandardLibraryConfig().disable) { CEL_RETURN_IF_ERROR( diff --git a/env/env_test.cc b/env/env_test.cc index 076eb57bc..b599aa569 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -314,6 +314,36 @@ TEST(ContainerConfigTest, ContainerConfig) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContainerConfigTest, ContainerConfigWithAbbreviations) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .abbreviations = {"cel.expr.conformance.proto2.TestAllTypes"}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAliases) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .aliases = { + {.alias = "MyTestType", + .qualified_name = "cel.expr.conformance.proto2.TestAllTypes"}}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("MyTestType{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 4ba16ea84..159786598 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -150,12 +150,72 @@ absl::Status 
ParseName(Config& config, absl::string_view yaml, absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node container = root["container"]; - if (container.IsDefined()) { - if (!container.IsScalar()) { - return YamlError(yaml, container, "Node 'container' is not a string"); - } + if (!container.IsDefined()) { + return absl::OkStatus(); + } + + if (container.IsScalar()) { config.SetContainerConfig({.name = GetString(yaml, container)}); + return absl::OkStatus(); } + + if (!container.IsMap()) { + return YamlError(yaml, container, + "Node 'container' is neither a string nor a map"); + } + + Config::ContainerConfig container_config; + + const YAML::Node name = container["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' in container is not a string"); + } + container_config.name = GetString(yaml, name); + } + + const YAML::Node abbreviations = container["abbreviations"]; + if (abbreviations.IsDefined()) { + if (!abbreviations.IsSequence()) { + return YamlError(yaml, abbreviations, + "Node 'abbreviations' is not a sequence"); + } + for (const YAML::Node& abbr : abbreviations) { + if (!abbr.IsScalar()) { + return YamlError(yaml, abbr, "Abbreviation is not a string"); + } + container_config.abbreviations.push_back(GetString(yaml, abbr)); + } + } + + const YAML::Node aliases = container["aliases"]; + if (aliases.IsDefined()) { + if (!aliases.IsSequence()) { + return YamlError(yaml, aliases, "Node 'aliases' is not a sequence"); + } + for (const YAML::Node& alias_node : aliases) { + if (!alias_node.IsMap()) { + return YamlError(yaml, alias_node, "Alias entry is not a map"); + } + const YAML::Node alias_key = alias_node["alias"]; + const YAML::Node qualified_name_key = alias_node["qualified_name"]; + + if (!alias_key.IsDefined() || !alias_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'alias' string"); + } + if 
(!qualified_name_key.IsDefined() || !qualified_name_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'qualified_name' string"); + } + + container_config.aliases.push_back( + {.alias = GetString(yaml, alias_key), + .qualified_name = GetString(yaml, qualified_name_key)}); + } + } + + config.SetContainerConfig(std::move(container_config)); return absl::OkStatus(); } @@ -686,7 +746,44 @@ void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { } out << YAML::Key << "container"; - out << YAML::Value << YAML::DoubleQuoted << container_config.name; + if (container_config.abbreviations.empty() && + container_config.aliases.empty()) { + out << YAML::Value << YAML::DoubleQuoted << container_config.name; + } else { + out << YAML::Value << YAML::BeginMap; + if (!container_config.name.empty()) { + out << YAML::Key << "name" << YAML::Value << YAML::DoubleQuoted + << container_config.name; + } + if (!container_config.abbreviations.empty()) { + std::vector sorted_abbrs = container_config.abbreviations; + absl::c_sort(sorted_abbrs); + out << YAML::Key << "abbreviations" << YAML::Value << YAML::BeginSeq; + for (const auto& abbr : sorted_abbrs) { + out << YAML::Value << YAML::DoubleQuoted << abbr; + } + out << YAML::EndSeq; + } + if (!container_config.aliases.empty()) { + std::vector sorted_aliases = + container_config.aliases; + absl::c_sort(sorted_aliases, [](const Config::ContainerConfig::Alias& a, + const Config::ContainerConfig::Alias& b) { + return a.alias < b.alias; + }); + out << YAML::Key << "aliases" << YAML::Value << YAML::BeginSeq; + for (const auto& alias : sorted_aliases) { + out << YAML::BeginMap; + out << YAML::Key << "alias" << YAML::Value << YAML::DoubleQuoted + << alias.alias; + out << YAML::Key << "qualified_name" << YAML::Value + << YAML::DoubleQuoted << alias.qualified_name; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } } void EmitExtensionConfigs(const Config& env_config, 
YAML::Emitter& out) { diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 25cc63206..d19c0dbfb 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -55,6 +55,31 @@ TEST(EnvYamlTest, ParseContainerConfig) { Field(&Config::ContainerConfig::name, "test.container")); } +TEST(EnvYamlTest, ParseContainerConfig_AlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: + name: test.container + abbreviations: + - abbr1.Abbr1 + - abbr2.Abbr2 + aliases: + - alias: alias1 + qualified_name: qual.name1 + - alias: alias2 + qualified_name: qual.name2 + )yaml")); + + const auto& container_config = config.GetContainerConfig(); + EXPECT_EQ(container_config.name, "test.container"); + EXPECT_THAT(container_config.abbreviations, + UnorderedElementsAre("abbr1.Abbr1", "abbr2.Abbr2")); + ASSERT_THAT(container_config.aliases, SizeIs(2)); + EXPECT_EQ(container_config.aliases[0].alias, "alias1"); + EXPECT_EQ(container_config.aliases[0].qualified_name, "qual.name1"); + EXPECT_EQ(container_config.aliases[1].alias, "alias2"); + EXPECT_EQ(container_config.aliases[1].qualified_name, "qual.name2"); +} + TEST(EnvYamlTest, ParseExtensionConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( extensions: @@ -550,9 +575,78 @@ INSTANTIATE_TEST_SUITE_P( container: - error: "error" )yaml", - .expected_error = "3:19: Node 'container' is not a string\n" - "| - error: \"error\"\n" - "| ^", + .expected_error = + "3:19: Node 'container' is neither a string nor a map\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + name: [] + )yaml", + .expected_error = "3:25: Node 'name' in container is not a string\n" + "| name: []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: "abbr" + )yaml", + .expected_error = "3:34: Node 'abbreviations' is not a sequence\n" + "| abbreviations: \"abbr\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + 
abbreviations: + - [] + )yaml", + .expected_error = "4:21: Abbreviation is not a string\n" + "| - []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: "not a sequence" + )yaml", + .expected_error = "3:28: Node 'aliases' is not a sequence\n" + "| aliases: \"not a sequence\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - "not a map" + )yaml", + .expected_error = "4:21: Alias entry is not a map\n" + "| - \"not a map\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - qualified_name: "qual" + )yaml", + .expected_error = "4:21: Alias entry missing 'alias' string\n" + "| - qualified_name: \"qual\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - alias: "my_alias" + )yaml", + .expected_error = "4:21: Alias entry missing" + " 'qualified_name' string\n" + "| - alias: \"my_alias\"\n" + "| ^", }, ParseTestCase{ .yaml = R"yaml( @@ -946,6 +1040,33 @@ std::vector GetExportTestCases() { container: "test.container" )yaml", }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig( + {.name = "test.container", + .abbreviations = {"foo", "bar"}, + .aliases = { + {.alias = "foo", .qualified_name = "test.foo"}, + {.alias = "bar", .qualified_name = "test.bar"}, + }}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: + name: "test.container" + abbreviations: + - "bar" + - "foo" + aliases: + - alias: "bar" + qualified_name: "test.bar" + - alias: "foo" + qualified_name: "test.foo" + )yaml", + }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; @@ -1385,6 +1506,18 @@ std::vector GetRoundTripTestCases() { overloads: - id: "string_to_timestamp" )yaml", + R"yaml( + container: + name: "test.container" + abbreviations: + - "abbr1.Abbr1" + - "abbr2.Abbr2" + aliases: + - alias: "alias1" + qualified_name: "qual.name1" + - alias: "alias2" + qualified_name: 
"qual.name2" + )yaml", R"yaml( extensions: - name: "bindings" From cd9f059a5833c92576e85e3ffb2eaee2fd328e76 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 12:32:02 -0700 Subject: [PATCH 62/88] Fix repeated field null pruning for proto2/proto3 PiperOrigin-RevId: 914421409 --- common/values/struct_value_builder.cc | 11 +++++++++++ conformance/BUILD | 6 ------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 359596267..c342d6478 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -812,6 +812,17 @@ ProtoMessageRepeatedFieldFromValueMutator( const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + // If the value is null and the target repeated field is anything except + // google.protobuf.{Any,ListValue,Struct,Value}, it should be pruned. 
+ if (value.IsNull()) { + const auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY) { + return absl::nullopt; + } + } auto* element = reflection->AddMessage(message, field, factory); auto result = ProtoMessageFromValueImpl(value, pool, factory, well_known_types, element); diff --git a/conformance/BUILD b/conformance/BUILD index 726a11b0b..abc0d918a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,15 +203,9 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/repeated_field_timestamp_null_pruned", - "proto2/set_null/repeated_field_duration_null_pruned", - "proto2/set_null/repeated_field_wrapper_null_pruned", "proto2/set_null/map_timestamp_null_pruned", "proto2/set_null/map_duration_null_pruned", "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/repeated_field_timestamp_null_pruned", - "proto3/set_null/repeated_field_duration_null_pruned", - "proto3/set_null/repeated_field_wrapper_null_pruned", "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", From 4749cf81003d9264fd87ca8b0640b5189bcc2b9e Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 16:00:59 -0700 Subject: [PATCH 63/88] Fix scientific notation and fixed point formatting for int and uint PiperOrigin-RevId: 914525320 --- conformance/BUILD | 4 ---- extensions/formatting.cc | 6 ++++++ extensions/formatting_test.cc | 12 ++++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index abc0d918a..4f9232ab6 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -209,10 +209,6 @@ _TESTS_TO_SKIP = 
[ "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", - "string_ext/format/default precision for fixed-point clause with int", - "string_ext/format/default precision for fixed-point clause with uint", - "string_ext/format/default precision for scientific notation with int", - "string_ext/format/default precision for scientific notation with uint", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 935815569..252fdc7bd 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -419,6 +419,12 @@ absl::StatusOr GetDouble(const Value& value, std::string& scratch) { str)); } } + if (value.kind() == ValueKind::kInt) { + return static_cast(value.GetInt().NativeValue()); + } + if (value.kind() == ValueKind::kUint) { + return static_cast(value.GetUint().NativeValue()); + } if (value.kind() != ValueKind::kDouble) { return absl::InvalidArgumentError( absl::StrCat("expected a double but got a ", value.GetTypeName())); diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index b80fe9bc0..6a7fb300b 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -553,6 +553,18 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "2.71828", .expected = "2.718280e+00", }, + { + .name = "FixedPointClauseWithInt", + .format = "%f", + .format_args = "3", + .expected = "3.000000", + }, + { + .name = "ScientificNotationWithUint", + .format = "%e", + .format_args = "uint(3)", + .expected = "3.000000e+00", + }, { .name = "NaNSupportForFixedPoint", .format = "%f", From 352666fba7822dd0d1f54dc00b332cc527aa81b1 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 13 May 2026 13:07:09 -0700 Subject: [PATCH 64/88] Fix map field value null pruning for proto2/proto3 PiperOrigin-RevId: 915015044 --- 
common/values/struct_value_builder.cc | 23 +++++++++++++++++++ conformance/BUILD | 6 ----- .../structs/proto_message_type_adapter.cc | 16 +++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index c342d6478..446b18421 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -956,6 +956,19 @@ class MessageValueBuilderImpl { if (error_value) { return false; } + if (map_value_field->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + entry_value.IsNull()) { + auto well_known_type = + map_value_field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } google::protobuf::MapValueRef proto_value; extensions::protobuf_internal::InsertOrLookupMapValue( *reflection_, message_, *field, proto_key, &proto_value); @@ -989,6 +1002,16 @@ class MessageValueBuilderImpl { CEL_RETURN_IF_ERROR(list_value->ForEach( [this, field, accessor, &error_value](const Value& element) -> absl::StatusOr { + if (field->message_type() != nullptr && element.IsNull()) { + auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } CEL_ASSIGN_OR_RETURN(error_value, (*accessor)(descriptor_pool_, message_factory_, &well_known_types_, reflection_, diff --git a/conformance/BUILD b/conformance/BUILD index 
4f9232ab6..ccd2844c9 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,12 +203,6 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/map_timestamp_null_pruned", - "proto2/set_null/map_duration_null_pruned", - "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/map_timestamp_null_pruned", - "proto3/set_null/map_duration_null_pruned", - "proto3/set_null/map_wrapper_null_pruned", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a351890c2..6a3417ba3 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -582,6 +582,19 @@ absl::Status ProtoMessageTypeAdapter::SetField( ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), "failed to find value field descriptor")); + bool prune_when_null = false; + if (value_field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + auto well_known_type = + value_field_descriptor->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + prune_when_null = true; + } + } + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list).Get(arena, i); @@ -589,6 +602,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( auto value = (*cel_map).Get(arena, key); CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), "error serializing CelMap")); + if (prune_when_null && value->IsNull()) { + 
continue; + } Message* entry_msg = message->GetReflection()->AddMessage(message, field); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( key, key_field_descriptor, entry_msg, arena)); From 037e0bb42339376640024de353451e372bb47820 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Wed, 13 May 2026 13:33:46 -0700 Subject: [PATCH 65/88] Adding a TypeSpec to Type resolver. PiperOrigin-RevId: 915029748 --- common/BUILD | 32 ++++ common/type_spec_resolver.cc | 182 +++++++++++++++++++++ common/type_spec_resolver.h | 37 +++++ common/type_spec_resolver_test.cc | 257 ++++++++++++++++++++++++++++++ 4 files changed, 508 insertions(+) create mode 100644 common/type_spec_resolver.cc create mode 100644 common/type_spec_resolver.h create mode 100644 common/type_spec_resolver_test.cc diff --git a/common/BUILD b/common/BUILD index 0ead8b15a..ffc4ae1e9 100644 --- a/common/BUILD +++ b/common/BUILD @@ -46,6 +46,38 @@ cc_test( ], ) +cc_library( + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], + deps = [ + ":ast", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..97451f390 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,182 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case 
PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + CEL_ASSIGN_OR_RETURN( + auto elem_type, + ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + CEL_ASSIGN_OR_RETURN( + auto key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + CEL_ASSIGN_OR_RETURN( + auto value_type, + ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + CEL_ASSIGN_OR_RETURN( + auto result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + std::vector arg_types; + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not 
found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..44e1e088f --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,37 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// 
you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Resolves a `cel::TypeSpec` to a `cel::Type`. +// +// TypeSpec only specifies a type while Type provides support for inspecting +// properties of the type when used in CEL. Returns a status with code +// `InvalidArgument` if the input cannot be resolved to a type. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..c7fbb2cf8 --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,257 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Values; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +using ConversionTest = testing::TestWithParam>; + +TEST_P(ConversionTest, TestTypeSpecConversion) { + ASSERT_OK_AND_ASSIGN( + auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_EQ(t.kind(), std::get<1>(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + TypeSpecResolverTest, ConversionTest, + testing::Values( + std::make_tuple(TypeSpec(PrimitiveType::kBool), TypeKind::kBool), + std::make_tuple(TypeSpec(PrimitiveType::kInt64), TypeKind::kInt), + std::make_tuple(TypeSpec(PrimitiveType::kUint64), TypeKind::kUint), + std::make_tuple(TypeSpec(PrimitiveType::kDouble), TypeKind::kDouble), + std::make_tuple(TypeSpec(PrimitiveType::kString), TypeKind::kString), + 
std::make_tuple(TypeSpec(PrimitiveType::kBytes), TypeKind::kBytes), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kAny), TypeKind::kAny), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kTimestamp), + TypeKind::kTimestamp), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kDuration), + TypeKind::kDuration), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + TypeKind::kBoolWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + TypeKind::kIntWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + TypeKind::kUintWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + TypeKind::kDoubleWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + TypeKind::kStringWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + TypeKind::kBytesWrapper))); + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + 
ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); +} + +TEST(TypeSpecResolverTest, OptionalType) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("optional_type", 
std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "optional_type"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + EXPECT_TRUE(t->IsOptional()); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); +} + 
+TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnknownTypeSpecKindError) { + TypeSpec spec; + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown TypeSpec kind"))); +} + +} // namespace +} // namespace cel From ad18948079b2d3d8b9e62a202889076f872992e7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:06:31 -0700 Subject: [PATCH 66/88] No public description PiperOrigin-RevId: 915223448 --- eval/public/ast_rewrite.cc | 2 +- eval/public/ast_traverse.cc | 2 +- eval/public/cel_attribute.cc | 4 ++-- eval/public/equality_function_registrar_test.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 3c210e607..87c667eb5 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -68,7 +68,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index a86923c67..c18b806b9 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -67,7 +67,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 015289bed..70525a04d 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -76,8 +76,8 @@ CelAttributeQualifier 
CreateCelAttributeQualifier(const CelValue& value) { CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 772ddfeba..577c4be22 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -86,7 +86,7 @@ MATCHER_P2(DefinesHomogenousOverload, name, argument_type, struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From 1f5a7e62900ae2ad1021228df04f2a950744c001 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:10:11 -0700 Subject: [PATCH 67/88] No public description PiperOrigin-RevId: 915224564 --- common/ast/constant_proto.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/ast/constant_proto.cc b/common/ast/constant_proto.cc index c0fe1c9f6..1982c05b4 100644 --- a/common/ast/constant_proto.cc +++ b/common/ast/constant_proto.cc @@ -35,7 +35,7 @@ using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, ConstantProto* absl_nonnull proto) { return absl::visit(absl::Overload( - [proto](absl::monostate) -> absl::Status { + [proto](std::monostate) -> absl::Status { proto->clear_constant_kind(); return absl::OkStatus(); }, From 6d311f704ade7aea062dd1091dfe3e683938fc78 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:12:46 -0700 Subject: [PATCH 68/88] No public description PiperOrigin-RevId: 915225296 --- internal/json.cc | 2 +- internal/message_equality.cc | 8 ++++---- internal/well_known_types.cc | 2 +- internal/well_known_types_test.cc | 2 +- 4 files changed, 7 insertions(+), 7 
deletions(-) diff --git a/internal/json.cc b/internal/json.cc index 630ceb267..cdd4c1a5d 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -1417,7 +1417,7 @@ class JsonMapIterator final { } private: - absl::variant variant_; + std::variant variant_; }; class JsonAccessor { diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 945cca8df..33ef78089 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -86,10 +86,10 @@ class EquatableMessage final }; using EquatableValue = - absl::variant; + std::variant; struct NullValueEqualer { bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index dee029534..02e50c3e3 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -2174,7 +2174,7 @@ absl::StatusOr AdaptFromMessage( if (adapted) { return adapted; } - return absl::monostate{}; + return std::monostate{}; } } diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc index 0d2c9fe33..afc8ce396 100644 --- a/internal/well_known_types_test.cc +++ b/internal/well_known_types_test.cc @@ -806,7 +806,7 @@ TEST_F(AdaptFromMessageTest, Struct) { TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT(AdaptFromMessage(*message), - IsOkAndHolds(VariantWith(absl::monostate()))); + IsOkAndHolds(VariantWith(std::monostate()))); } TEST_F(AdaptFromMessageTest, Any_BoolValue) { From fb51dcdfd1082e67d209c1ba0c84e58b577c378a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:13:04 -0700 Subject: [PATCH 69/88] No public description PiperOrigin-RevId: 915268033 --- runtime/internal/convert_constant.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc index a9effd229..33f382858 100644 --- a/runtime/internal/convert_constant.cc +++ 
b/runtime/internal/convert_constant.cc @@ -33,7 +33,7 @@ using ::cel::Constant; struct ConvertVisitor { Allocator<> allocator; - absl::StatusOr operator()(absl::monostate) { + absl::StatusOr operator()(std::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } From 877239571674284da3d22bcc7ccfe2e175643de7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:15:52 -0700 Subject: [PATCH 70/88] No public description PiperOrigin-RevId: 915269088 --- common/values/message_value.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/common/values/message_value.cc b/common/values/message_value.cc index e06206407..66dfd9511 100644 --- a/common/values/message_value.cc +++ b/common/values/message_value.cc @@ -46,7 +46,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c ABSL_CHECK(*this); // Crash OK return absl::visit( absl::Overload( - [](absl::monostate) -> const google::protobuf::Descriptor* absl_nonnull { + [](std::monostate) -> const google::protobuf::Descriptor* absl_nonnull { ABSL_UNREACHABLE(); }, [](const ParsedMessageValue& alternative) @@ -58,7 +58,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c std::string MessageValue::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) -> std::string { return "INVALID"; }, + absl::Overload([](std::monostate) -> std::string { return "INVALID"; }, [](const ParsedMessageValue& alternative) -> std::string { return alternative.DebugString(); }), @@ -68,7 +68,7 @@ std::string MessageValue::DebugString() const { bool MessageValue::IsZeroValue() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) -> bool { return true; }, + absl::Overload([](std::monostate) -> bool { return true; }, [](const ParsedMessageValue& alternative) -> bool { return 
alternative.IsZeroValue(); }), @@ -81,7 +81,7 @@ absl::Status MessageValue::SerializeTo( google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -99,7 +99,7 @@ absl::Status MessageValue::ConvertToJson( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -117,7 +117,7 @@ absl::Status MessageValue::ConvertToJsonObject( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJsonObject` on " "an invalid `MessageValue`"); @@ -136,7 +136,7 @@ absl::Status MessageValue::Equal( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Equal` on " "an invalid `MessageValue`"); @@ -155,7 +155,7 @@ absl::Status MessageValue::GetFieldByName( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByName` on " "an invalid `MessageValue`"); @@ -175,7 +175,7 @@ absl::Status MessageValue::GetFieldByNumber( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) 
-> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByNumber` on " "an invalid `MessageValue`"); @@ -192,7 +192,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::string_view name) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByName` on " "an invalid `MessageValue`"); @@ -206,7 +206,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByNumber` on " "an invalid `MessageValue`"); @@ -224,7 +224,7 @@ absl::Status MessageValue::ForEachField( google::protobuf::Arena* absl_nonnull arena) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ForEachField` on " "an invalid `MessageValue`"); @@ -244,7 +244,7 @@ absl::Status MessageValue::Qualify( int* absl_nonnull count) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Qualify` on " "an invalid `MessageValue`"); From ff45a7c2a096ed1d38e6ed4d80a7180be32874b7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:39:55 -0700 Subject: [PATCH 71/88] No public description PiperOrigin-RevId: 915340723 --- common/ast_rewrite.cc | 2 +- common/ast_traverse.cc | 2 +- common/decl_proto.cc | 2 +- common/decl_proto_test.cc | 4 ++-- common/decl_proto_v1alpha1.cc | 2 +- common/type.cc | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/ast_rewrite.cc b/common/ast_rewrite.cc 
index 14582f44f..b61e1fab6 100644 --- a/common/ast_rewrite.cc +++ b/common/ast_rewrite.cc @@ -54,7 +54,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc index a6ba0d1ba..fb4f9731e 100644 --- a/common/ast_traverse.cc +++ b/common/ast_traverse.cc @@ -53,7 +53,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/decl_proto.cc b/common/decl_proto.cc index 89f7f4453..098c5068c 100644 --- a/common/decl_proto.cc +++ b/common/decl_proto.cc @@ -69,7 +69,7 @@ absl::StatusOr FunctionDeclFromProto( return decl; } -absl::StatusOr> DeclFromProto( +absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/decl_proto_test.cc b/common/decl_proto_test.cc index 62215f07f..d72d97e09 100644 --- a/common/decl_proto_test.cc +++ b/common/decl_proto_test.cc @@ -49,7 +49,7 @@ TEST_P(DeclFromProtoTest, FromProtoWorks) { cel::expr::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromProto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { @@ -79,7 +79,7 @@ TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { google::api::expr::v1alpha1::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { diff --git a/common/decl_proto_v1alpha1.cc b/common/decl_proto_v1alpha1.cc index 2c6cfb6e4..a8d73e5c2 100644 --- a/common/decl_proto_v1alpha1.cc +++ b/common/decl_proto_v1alpha1.cc @@ -52,7 
+52,7 @@ absl::StatusOr FunctionDeclFromV1Alpha1Proto( return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); } -absl::StatusOr> DeclFromV1Alpha1Proto( +absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/type.cc b/common/type.cc index ce8c7a89a..f94e8bc52 100644 --- a/common/type.cc +++ b/common/type.cc @@ -97,7 +97,7 @@ static constexpr std::array kTypeToKindArray = { TypeKind::kUnknown}; static_assert(kTypeToKindArray.size() == - absl::variant_size(), + std::variant_size(), "Kind indexer must match variant declaration for cel::Type."); } // namespace From 366498bd3820ab8382282ca15279753e4789be31 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:46:00 -0700 Subject: [PATCH 72/88] No public description PiperOrigin-RevId: 915342975 --- common/types/struct_type.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc index 4540cec9c..a1be1f786 100644 --- a/common/types/struct_type.cc +++ b/common/types/struct_type.cc @@ -27,7 +27,7 @@ namespace cel { absl::string_view StructType::name() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) { return absl::string_view(); }, + absl::Overload([](std::monostate) { return absl::string_view(); }, [](const common_internal::BasicStructType& alt) { return alt.name(); }, @@ -39,7 +39,7 @@ TypeParameters StructType::GetParameters() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload( - [](absl::monostate) { return TypeParameters(); }, + [](std::monostate) { return TypeParameters(); }, [](const common_internal::BasicStructType& alt) { return alt.GetParameters(); }, @@ -49,7 +49,7 @@ TypeParameters StructType::GetParameters() const { std::string StructType::DebugString() const { return absl::visit( - 
absl::Overload([](absl::monostate) { return std::string(); }, + absl::Overload([](std::monostate) { return std::string(); }, [](common_internal::BasicStructType alt) { return alt.DebugString(); }, @@ -72,7 +72,7 @@ MessageType StructType::GetMessage() const { common_internal::TypeVariant StructType::ToTypeVariant() const { return absl::visit( absl::Overload( - [](absl::monostate) { return common_internal::TypeVariant(); }, + [](std::monostate) { return common_internal::TypeVariant(); }, [](common_internal::BasicStructType alt) { return static_cast(alt) ? common_internal::TypeVariant(alt) : common_internal::TypeVariant(); From 33156b1e59b458ff6c24208dbcb66ace3186ab9b Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 04:28:29 -0700 Subject: [PATCH 73/88] No public description PiperOrigin-RevId: 915359111 --- extensions/select_optimization.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 0f09773ae..44da4c48a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -92,7 +92,7 @@ struct SelectInstruction { // Represents a single qualifier in a traversal path. // TODO(uncreated-issue/51): support variable indexes. 
using QualifierInstruction = - absl::variant; + std::variant; struct SelectPath { Expr* operand; From a88afaca5106943d9f835cb622be9813b6bdee55 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 05:07:40 -0700 Subject: [PATCH 74/88] No public description PiperOrigin-RevId: 915372200 --- runtime/memory_safety_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 7e864ecf6..2a09be666 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -73,7 +73,7 @@ struct TestCase { std::string name; std::string expression; absl::flat_hash_map> + std::variant> activation; test::ValueMatcher expected_matcher; bool reference_resolver_enabled = false; From 513af3c2c338c0aabbbef21419018880dc9c23c4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:12 -0700 Subject: [PATCH 75/88] No public description PiperOrigin-RevId: 915787513 --- eval/internal/cel_value_equal_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/internal/cel_value_equal_test.cc b/eval/internal/cel_value_equal_test.cc index f52f38916..109a63795 100644 --- a/eval/internal/cel_value_equal_test.cc +++ b/eval/internal/cel_value_equal_test.cc @@ -67,7 +67,7 @@ using ::testing::ValuesIn; struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From f62419d04f3a4c12ecf2a802e95d26e33aa2b115 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:24 -0700 Subject: [PATCH 76/88] No public description PiperOrigin-RevId: 915787600 --- tools/branch_coverage.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc index 00ab7cb5a..b5bba3ffe 100644 --- a/tools/branch_coverage.cc +++ 
b/tools/branch_coverage.cc @@ -71,7 +71,7 @@ struct OtherNode { // Representation for coverage of an AST node. struct CoverageNode { int evaluate_count; - absl::variant kind; + std::variant kind; }; const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, From 7820913cda14e09bbb667c10d65391c1a79fb95d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:16:10 -0700 Subject: [PATCH 77/88] No public description PiperOrigin-RevId: 915792480 --- common/value.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/value.cc b/common/value.cc index 535ddead8..1cd3f54e1 100644 --- a/common/value.cc +++ b/common/value.cc @@ -115,7 +115,7 @@ Type Value::GetRuntimeType() const { namespace { template -struct IsMonostate : std::is_same, absl::monostate> {}; +struct IsMonostate : std::is_same, std::monostate> {}; } // namespace @@ -171,7 +171,7 @@ absl::Status Value::ConvertToJsonArray( google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -212,7 +212,7 @@ absl::Status Value::ConvertToJsonObject( google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -1363,7 +1363,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->CopyFrom(message); return ParsedMessageValue(cloned, arena); @@ -1391,7 +1391,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ 
/* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->GetReflection()->Swap(cloned, &message); return ParsedMessageValue(cloned, arena); @@ -1422,7 +1422,7 @@ Value Value::WrapMessage( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { auto* cloned = message->New(arena); cloned->CopyFrom(*message); @@ -1456,7 +1456,7 @@ Value Value::WrapMessageUnsafe( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { return UnsafeParsedMessageValue(message); } From da45d34071c8fe9f77fcc17e33e518841d382cdc Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Mon, 18 May 2026 08:43:19 -0700 Subject: [PATCH 78/88] Add missing include for `google/rpc/status.proto.h`. This code was relying on the transitive inclusion of third_party/cel/cpp/* to provide the type information for the Status proto. This makes the code brittle and prone to breakages when doing internal header refactors. 
PiperOrigin-RevId: 917253701 --- conformance/BUILD | 1 + conformance/service.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/conformance/BUILD b/conformance/BUILD index ccd2844c9..0ca90a4bc 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -83,6 +83,7 @@ cc_library( "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/conformance/service.cc b/conformance/service.cc index 463334bb5..7e3eded82 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -30,6 +30,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/rpc/code.pb.h" +#include "google/rpc/status.pb.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" From d475cc6726ef85fefed557c8eb0e400119d13e95 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 20:28:14 -0700 Subject: [PATCH 79/88] No public description PiperOrigin-RevId: 918791547 --- codelab/network_functions.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index f4f729827..64f199cb3 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -213,8 +213,7 @@ absl::Status NetworkAddressRepEqual( return absl::OkStatus(); } const NetworkAddressRep rep = content.To(); - absl::optional other_rep = - NetworkAddressRep::Unwrap(other); + std::optional other_rep = NetworkAddressRep::Unwrap(other); ABSL_DCHECK(other_rep.has_value()); *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); return absl::OkStatus(); @@ -311,7 +310,7 @@ cel::Value parseAddress( 
google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); } @@ -321,7 +320,7 @@ cel::Value parseAddress( cel::Value parseAddressOrZero(const cel::StringValue& str) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); static const NetworkAddressRep kZero; if (!rep.has_value()) { return NetworkAddressRep::MakeValue(kZero); @@ -336,8 +335,7 @@ cel::Value parseAddressMatcher( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = - NetworkAddressMatcher::Parse(addr); + std::optional rep = NetworkAddressMatcher::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue( absl::InvalidArgumentError("invalid address matcher")); @@ -365,7 +363,7 @@ cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { cel::OpaqueValueContent::From(rep)); } -absl::optional NetworkAddressRep::Unwrap( +std::optional NetworkAddressRep::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || @@ -381,7 +379,7 @@ absl::optional NetworkAddressRep::Unwrap( return opaque->content().To(); } -absl::optional NetworkAddressRep::Parse( +std::optional NetworkAddressRep::Parse( absl::string_view str) { uint32_t ipv4 = 0; char ipv6[16]; @@ -418,7 +416,7 @@ bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { return false; } -absl::optional NetworkAddressMatcher::Parse( +std::optional NetworkAddressMatcher::Parse( absl::string_view str) { // range style addr-addr int dash_pos = str.find('-'); From b7096df80e0d7b6facc2943326a1c04cde0f1d27 Mon Sep 17 00:00:00 2001 From: CEL Dev 
Team Date: Wed, 20 May 2026 20:28:15 -0700 Subject: [PATCH 80/88] No public description PiperOrigin-RevId: 918791557 --- eval/public/equality_function_registrar_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 577c4be22..a77a92734 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -204,7 +204,7 @@ std::string CelValueEqualTestName( } TEST_P(CelValueEqualImplTypesTest, Basic) { - absl::optional result = CelValueEqualImpl(lhs(), rhs()); + std::optional result = CelValueEqualImpl(lhs(), rhs()); if (lhs().IsNull() || rhs().IsNull()) { if (lhs().IsNull() && rhs().IsNull()) { @@ -286,7 +286,7 @@ const std::vector& NumericValuesNotEqualExample() { using NumericInequalityTest = testing::TestWithParam; TEST_P(NumericInequalityTest, NumericValues) { NumericInequalityTestCase test_case = GetParam(); - absl::optional result = CelValueEqualImpl(test_case.a, test_case.b); + std::optional result = CelValueEqualImpl(test_case.a, test_case.b); EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, false); } @@ -299,7 +299,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST(CelValueEqualImplTest, LossyNumericEquality) { - absl::optional result = CelValueEqualImpl( + std::optional result = CelValueEqualImpl( CelValue::CreateDouble( static_cast(std::numeric_limits::max()) - 1), CelValue::CreateInt64(std::numeric_limits::max())); From 719f3eed5919bc964b30c7e06a77a3a7eeb64953 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 21:47:09 -0700 Subject: [PATCH 81/88] No public description PiperOrigin-RevId: 918818689 --- eval/tests/benchmark_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index fc0c39294..f188dc0b7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -317,7 +317,7 
@@ BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: - absl::optional operator[](CelValue key) const override { + std::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } From a552526a3c58346438cec05cbfe3afeb20657ed6 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Thu, 21 May 2026 12:28:31 -0700 Subject: [PATCH 82/88] Add functions to parse type and function signatures into cel types. PiperOrigin-RevId: 919194450 --- common/internal/BUILD | 8 +- common/internal/signature.cc | 390 +++++++++++++++++++++++- common/internal/signature.h | 21 ++ common/internal/signature_test.cc | 489 ++++++++++++++++++++++++++++-- 4 files changed, 889 insertions(+), 19 deletions(-) diff --git a/common/internal/BUILD b/common/internal/BUILD index 10084b685..48a8dfe8b 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -143,14 +143,17 @@ cc_library( srcs = ["signature.cc"], hdrs = ["signature.h"], deps = [ + "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) @@ -159,11 +162,14 @@ cc_test( srcs = ["signature_test.cc"], deps = [ ":signature", + "//common:ast", "//common:type", + "//common:type_kind", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/common/internal/signature.cc b/common/internal/signature.cc index f63049878..5c75225f9 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -15,20 +15,30 @@ #include 
"common/internal/signature.h" #include +#include +#include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { +// Signature generator helper functions. namespace { void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { @@ -58,7 +68,7 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type); // Recursively appends a string representation of the given `type` to `result`. // Type parameters are enclosed in angle brackets and separated by commas. - +// // Grammar: // TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; // NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; @@ -208,4 +218,382 @@ absl::StatusOr MakeOverloadSignature( return result; } + +// Signature parser helper functions. 
+namespace { + +std::string StripUnescapedWhitespace(std::string_view str) { + std::string result; + result.reserve(str.size()); + bool escaped = false; + for (char c : str) { + if (escaped) { + result.push_back(c); + escaped = false; + continue; + } + if (c == '\\') { + result.push_back(c); + escaped = true; + continue; + } + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + continue; + } + result.push_back(c); + } + return result; +} + +absl::optional ParseBuiltinOrWrapper(std::string_view name_str) { + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any" || name_str == "google.protobuf.Any") + return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp" || name_str == "google.protobuf.Timestamp") + return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration" || name_str == "google.protobuf.Duration") + return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn" || name_str == "google.protobuf.Value") + return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. 
+ if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "google.protobuf.ListValue") { + return TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + if (name_str == "google.protobuf.Struct") { + return TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))); + } + + return absl::nullopt; +} + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view 
error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. 
+ for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the 
signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [&params](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (!params.empty()) { + if (ParseBuiltinOrWrapper(name_str).has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid type signature: ", name_str, + " cannot have type parameters")); + } + } else { + if (auto builtin = ParseBuiltinOrWrapper(name_str); builtin.has_value()) { + return *builtin; + } + } + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i <
params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + if (stripped_sig.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN( + size_t paren_idx, + SignatureScanner(stripped_sig, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || stripped_sig.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = std::string_view(stripped_sig).substr(0, paren_idx); + std::string_view args_str = + std::string_view(stripped_sig) + .substr(paren_idx + 1, stripped_sig.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, 
SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + } // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h index 3f31d8fd1..3fdba4b2e 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -20,7 +20,10 @@ #include #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -56,6 +59,24 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + // The function signature type, configured as a `FunctionTypeSpec`. + TypeSpec signature_type; +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. 
+absl::StatusOr ParseFunctionSignature( + std::string_view signature); + } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 8e41c70fb..765055f75 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -38,6 +42,101 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } +void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { + switch (original.kind()) { + case TypeKind::kDyn: + EXPECT_TRUE(parsed.has_dyn()); + break; + case TypeKind::kNull: + EXPECT_TRUE(parsed.has_null()); + break; + case TypeKind::kBool: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); + break; + case TypeKind::kInt: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); + break; + case TypeKind::kUint: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); + break; + case TypeKind::kDouble: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); + break; + case TypeKind::kString: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); + break; + case TypeKind::kBytes: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); + break; + case TypeKind::kAny: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); + break; + case TypeKind::kTimestamp: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); + break; + case TypeKind::kDuration: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); 
+ break; + case TypeKind::kList: + EXPECT_TRUE(parsed.has_list_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.list_type().elem_type(), + original.GetParameters()[0]); + } + break; + case TypeKind::kMap: + EXPECT_TRUE(parsed.has_map_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.map_type().key_type(), + original.GetParameters()[0]); + } + if (original.GetParameters().size() > 1) { + VerifyParsedMatchesType(parsed.map_type().value_type(), + original.GetParameters()[1]); + } + break; + case TypeKind::kBoolWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + case TypeKind::kDoubleWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kBytesWrapper: + EXPECT_TRUE(parsed.has_wrapper()); + break; + case TypeKind::kType: + EXPECT_TRUE(parsed.has_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); + } + break; + case TypeKind::kTypeParam: + EXPECT_TRUE(parsed.has_type_param()); + break; + default: + EXPECT_TRUE(parsed.has_abstract_type()); + break; + } +} + +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kStruct || + lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + struct TypeSignatureTestCase { Type type; std::string expected_signature; @@ -73,10 +172,18 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), StringType{}), .expected_signature = "list", }, + { + .type = TypeType(GetTestArena(), 
IntType{}), + .expected_signature = "type", + }, { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, + { + .type = ListType(GetTestArena(), TypeParamType("A GetTypeSignatureTestCases() { .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})}), - .expected_signature = "bar>", + .type = OpaqueType(GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), + {StringType{}, BoolType{}})}), + .expected_signature = "bar>", }, { .type = AnyType{}, @@ -104,10 +211,18 @@ std::vector GetTypeSignatureTestCases() { .type = TimestampType{}, .expected_signature = "timestamp", }, + { + .type = BoolWrapperType{}, + .expected_signature = "bool_wrapper", + }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, + { + .type = UintWrapperType{}, + .expected_signature = "uint_wrapper", + }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), @@ -117,22 +232,32 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", }, - { - .type = UnknownType{}, - .expected_error = - "Type kind: *unknown* is not supported in CEL declarations", - }, - { - .type = ErrorType{}, - .expected_error = - "Type kind: *error* is not supported in CEL declarations", - }, }; } +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *unknown* is not supported"))); + + EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *error* is not supported"))); +} + INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); +TEST_P(TypeSignatureTest, 
ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + VerifyTypesEqual(*parsed, param.type); + } +} + struct OverloadSignatureTestCase { std::string function_name = "hello"; std::vector args; @@ -202,10 +327,18 @@ std::vector GetOverloadSignatureTestCases() { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, + { + .args = {BoolWrapperType{}}, + .expected_signature = "hello(bool_wrapper)", + }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, + { + .args = {UintWrapperType{}}, + .expected_signature = "hello(uint_wrapper)", + }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( @@ -213,9 +346,6 @@ std::vector GetOverloadSignatureTestCases() { .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, - {.args = {}, - .is_member = true, - .expected_error = "Member function with no receiver"}, { .args = {StringType{}}, .is_member = true, @@ -231,6 +361,18 @@ std::vector GetOverloadSignatureTestCases() { .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, + { + .function_name = "hello", + .args = {OpaqueType(GetTestArena(), "bar", + {TypeParamType("dummy.type")})}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .expected_signature = "inspect(type)", + }, { .function_name = R"(h.(e),l\o)", .args = {StringType{}, @@ -242,8 +384,321 @@ std::vector GetOverloadSignatureTestCases() { }; } +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member 
function with no receiver"))); +} + INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + 
EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto gp_any, + ParseType("google.protobuf.Any", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_any.IsAny()); + + ASSERT_OK_AND_ASSIGN(auto gp_timestamp, + ParseType("google.protobuf.Timestamp", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_timestamp.IsTimestamp()); + + ASSERT_OK_AND_ASSIGN(auto gp_duration, + ParseType("google.protobuf.Duration", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_duration.IsDuration()); + + ASSERT_OK_AND_ASSIGN(auto gp_value, + ParseType("google.protobuf.Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_value.IsDyn()); + + ASSERT_OK_AND_ASSIGN(auto gp_list_value, + ParseType("google.protobuf.ListValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_list_value.IsList()); + + ASSERT_OK_AND_ASSIGN(auto gp_struct, + ParseType("google.protobuf.Struct", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_struct.IsMap()); + + // Legal 
whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_type1, + ParseType("map < int , string > ", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type1.IsMap()); + + ASSERT_OK_AND_ASSIGN(auto ws_type2, + ParseType("map\t<\nint\r,\tstring\n>\r", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type2.IsMap()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_func1, + ParseFunctionSignature(" hello ( string ) ")); + EXPECT_TRUE(ws_func1.signature_type.has_function()); + EXPECT_EQ(ws_func1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto ws_func2, + ParseFunctionSignature("\thello\n(\rstring\t)\n\r")); + EXPECT_TRUE(ws_func2.signature_type.has_function()); + EXPECT_EQ(ws_func2.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. 
+ EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. 
+ EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + // Checks that builtin types cannot have type parameters. 
+ EXPECT_THAT( + ParseType("int", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, 
EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + } // namespace } // namespace cel::common_internal From 6fd7030b954562c5e5c1c1185066e80cad29fd25 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 21 May 2026 18:17:46 -0700 Subject: [PATCH 83/88] not yet exported PiperOrigin-RevId: 919358445 --- common/expr_factory.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index 773217ad9..5607d8deb 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -32,6 +32,7 @@ namespace cel { class MacroExprFactory; class ParserMacroExprFactory; +class OptimizerExprFactory; class ExprFactory { protected: @@ -378,6 +379,7 @@ class ExprFactory { private: friend class MacroExprFactory; friend class ParserMacroExprFactory; + friend class OptimizerExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} From ec82288de1338c6d7763fd722d52c3636965ca1e Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:41:46 -0700 Subject: [PATCH 84/88] No public description PiperOrigin-RevId: 919408845 --- .../descriptor_pool_type_introspector.cc | 12 +++--- .../descriptor_pool_type_introspector_test.cc | 4 +- checker/internal/type_check_env.cc | 10 ++--- checker/internal/type_checker_builder_impl.cc | 6 +-- .../type_checker_builder_impl_test.cc | 2 +- checker/internal/type_checker_impl.cc | 20 +++++----- checker/internal/type_inference_context.cc | 8 ++-- .../internal/type_inference_context_test.cc | 40 +++++++++---------- 8 files changed, 51 insertions(+), 51 deletions(-) diff 
--git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc index f6001e947..da4f4430b 100644 --- a/checker/internal/descriptor_pool_type_introspector.cc +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -35,7 +35,7 @@ namespace { // Standard implementation for field lookups. // Avoids building a FieldTable and just checks the DescriptorPool directly. -absl::StatusOr> +absl::StatusOr> FindStructTypeFieldByNameDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type, absl::string_view name) { @@ -60,7 +60,7 @@ FindStructTypeFieldByNameDirectly( // Standard implementation for listing fields. // Avoids building a FieldTable and just checks the DescriptorPool directly. absl::StatusOr< - absl::optional>> + std::optional>> ListStructTypeFieldsDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type) { @@ -88,7 +88,7 @@ ListStructTypeFieldsDirectly( using Field = DescriptorPoolTypeIntrospector::Field; -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool_->FindMessageTypeByName(name); @@ -103,7 +103,7 @@ DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindEnumConstantImpl( absl::string_view type, absl::string_view value) const { const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = @@ -124,7 +124,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { if (!use_json_name_) { @@ -151,7 +151,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( } 
absl::StatusOr< - absl::optional>> + std::optional>> DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( absl::string_view type) const { if (!use_json_name_) { diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc index e2fdc9d40..456798744 100644 --- a/checker/internal/descriptor_pool_type_introspector_test.cc +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -117,7 +117,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); - absl::StatusOr> field = + absl::StatusOr> field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); @@ -132,7 +132,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); absl::StatusOr< - absl::optional>> + std::optional>> fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index c080326cb..763d9ba46 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -48,7 +48,7 @@ const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( return nullptr; } -absl::StatusOr> TypeCheckEnv::LookupTypeName( +absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -60,7 +60,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupEnumConstant( +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -77,9 +77,9 @@ 
absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupTypeConstant( +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { - CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); + CEL_ASSIGN_OR_RETURN(std::optional type, LookupTypeName(name)); if (type.has_value()) { return MakeVariableDecl(type->name(), TypeType(arena, *type)); } @@ -94,7 +94,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupStructField( +absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 94a05602e..85b581e83 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -158,8 +158,8 @@ absl::StatusOr MergeFunctionDecls( return merged_decl; } -absl::optional FilterDecl(FunctionDecl decl, - const TypeCheckerSubset& subset) { +std::optional FilterDecl(FunctionDecl decl, + const TypeCheckerSubset& subset) { FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); @@ -283,7 +283,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( for (FunctionDeclRecord& fn : config.functions) { FunctionDecl decl = std::move(fn.decl); if (subset != nullptr) { - absl::optional filtered = + std::optional filtered = FilterDecl(std::move(decl), *subset); if (!filtered.has_value()) { continue; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index f7a3dff97..494e7e440 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ 
b/checker/internal/type_checker_builder_impl_test.cc @@ -144,7 +144,7 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { {}); class MyTypeProvider : public cel::TypeIntrospector { public: - absl::StatusOr> FindTypeImpl( + absl::StatusOr> FindTypeImpl( absl::string_view name) const override { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 2472d7def..1ce871255 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -379,7 +379,7 @@ class ResolveVisitor : public AstVisitorBase { // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( - absl::optional field_info, + std::optional field_info, env_->LookupStructField(resolved_name, field.name())); if (!field_info.has_value()) { ReportUndefinedField(field.id(), field.name(), resolved_name); @@ -405,8 +405,8 @@ class ResolveVisitor : public AstVisitorBase { return absl::OkStatus(); } - absl::optional CheckFieldType(int64_t expr_id, const Type& operand_type, - absl::string_view field_name); + std::optional CheckFieldType(int64_t expr_id, const Type& operand_type, + absl::string_view field_name); void HandleOptSelect(const Expr& expr); void HandleBlockIndex(const Expr* expr); @@ -919,7 +919,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } - absl::optional resolution = + std::optional resolution = inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { @@ -968,7 +968,7 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } - absl::StatusOr> constant = + absl::StatusOr> constant = env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { @@ -1079,9 
+1079,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier( } } -absl::optional ResolveVisitor::CheckFieldType(int64_t id, - const Type& operand_type, - absl::string_view field) { +std::optional ResolveVisitor::CheckFieldType(int64_t id, + const Type& operand_type, + absl::string_view field) { if (operand_type.kind() == TypeKind::kDyn || operand_type.kind() == TypeKind::kAny) { return DynType(); @@ -1137,7 +1137,7 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr, const Expr& operand) { const Type& operand_type = GetDeducedType(&operand); - absl::optional result_type; + std::optional result_type; int64_t id = expr.id(); // Support short-hand optional chaining. if (operand_type.IsOptional()) { @@ -1184,7 +1184,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { operand_type = operand_type.GetOptional().GetParameter(); } - absl::optional field_type = CheckFieldType( + std::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { types_[&expr] = ErrorType(); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 96d985071..5b909d982 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -133,7 +133,7 @@ FunctionOverloadInstance InstantiateFunctionOverload( // Converts a wrapper type to its corresponding primitive type. // Returns nullopt if the type is not a wrapper type. -absl::optional WrapperToPrimitive(const Type& t) { +std::optional WrapperToPrimitive(const Type& t) { switch (t.kind()) { case TypeKind::kBoolWrapper: return BoolType(); @@ -286,7 +286,7 @@ bool TypeInferenceContext::IsAssignableInternal( } // Type is as concrete as it can be under current substitutions. 
- if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); + if (std::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, @@ -531,11 +531,11 @@ bool TypeInferenceContext::IsAssignableWithConstraints( return false; } -absl::optional +std::optional TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, absl::Span argument_types, bool is_receiver) { - absl::optional result_type; + std::optional result_type; std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index d1bf7fa6d..458d08ff1 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -291,7 +291,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadBasic) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); @@ -309,7 +309,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadFails) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -324,7 +324,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -341,7 +341,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) { "_==_", 
MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_a}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); } @@ -359,7 +359,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_int}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); @@ -375,7 +375,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsBool()); @@ -394,7 +394,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload( decl, {list_of_a_instance, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution.has_value()); @@ -407,7 +407,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list"))); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {ListType(&arena, IntType()), list_of_a_instance}, false); ASSERT_TRUE(resolution2.has_value()); @@ -433,7 +433,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional 
resolution = context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false); EXPECT_FALSE(resolution.has_value()); } @@ -450,13 +450,13 @@ TEST(TypeInferenceContextTest, InferencesAccumulate) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution1 = + std::optional resolution1 = context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance}, false); ASSERT_TRUE(resolution1.has_value()); EXPECT_TRUE(resolution1->result_type.IsList()); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {resolution1->result_type, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution2.has_value()); @@ -480,7 +480,7 @@ TEST(TypeInferenceContextTest, DebugString) { MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_int, list_of_int}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsList()); @@ -517,7 +517,7 @@ class TypeInferenceContextWrapperTypesTest TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapped_primitive_type}, @@ -534,7 +534,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload( ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false); @@ -550,7 +550,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { const 
TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, NullType()}, false); @@ -566,7 +566,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), NullType(), test_case.wrapper_type}, false); @@ -582,7 +582,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapped_primitive_type, test_case.wrapper_type}, @@ -622,7 +622,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) { /*result_type=*/TypeParamType("A"), BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -648,7 +648,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) { TypeType(&arena, TypeParamType("A")), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -680,7 +680,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(to_type_decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); From 833bb0c8fe93dc9bf5e2971c38254f80d3a42c1f Mon Sep 17 
00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:46:26 -0700 Subject: [PATCH 85/88] No public description PiperOrigin-RevId: 919410392 --- testutil/test_macros.cc | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc index 158135762..672439dc5 100644 --- a/testutil/test_macros.cc +++ b/testutil/test_macros.cc @@ -37,9 +37,8 @@ bool IsCelNamespace(const Expr& target) { return target.has_ident_expr() && target.ident_expr().name() == "cel"; } -absl::optional CelBlockMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -51,9 +50,8 @@ absl::optional CelBlockMacroExpander(MacroExprFactory& factory, return factory.NewCall("cel.@block", args); } -absl::optional CelIndexMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -70,9 +68,9 @@ absl::optional CelIndexMacroExpander(MacroExprFactory& factory, return factory.NewIdent(absl::StrCat("@index", index)); } -absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -94,9 +92,9 @@ absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, unique_arg.const_expr().int_value())); } -absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } From 
e00189c323d42b0cacafc320aa890eb4d630d394 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:26:10 -0700 Subject: [PATCH 86/88] No public description PiperOrigin-RevId: 920098804 --- runtime/activation_test.cc | 4 ++-- runtime/function_registry.cc | 5 ++--- runtime/function_registry_test.cc | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index e6a74f027..4303116a3 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -326,7 +326,7 @@ TEST_F(ActivationTest, MoveAssignment) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), @@ -377,7 +377,7 @@ TEST_F(ActivationTest, MoveCtor) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index ac1e53eb5..59f267255 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -44,14 +44,13 @@ class ActivationFunctionProviderImpl public: ActivationFunctionProviderImpl() = default; - absl::StatusOr> GetFunction( + absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); - absl::optional matching_overload = - 
absl::nullopt; + std::optional matching_overload = absl::nullopt; for (const auto& overload : overloads) { if (overload.descriptor.ShapeMatches(descriptor)) { diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index af7f5bc06..53916777a 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -120,7 +120,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, activation)); @@ -146,7 +146,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction( FunctionDescriptor("LazyFunction", false, {Kind::kInt}), activation)); From 267f4de9814c320e61b20cdd8fbe6580f2e57ac3 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:51:01 -0700 Subject: [PATCH 87/88] No public description PiperOrigin-RevId: 920105999 --- extensions/select_optimization.cc | 16 ++++++++-------- extensions/select_optimization_test.cc | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 44da4c48a..42cad0f92 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -153,7 +153,7 @@ Expr MakeSelectPathExpr( // Returns a single select operation based on the inferred type of the operand // and the field name. If the operand type doesn't define the field, returns // nullopt. 
-absl::optional GetSelectInstruction( +std::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { auto field_or = planner_context.type_reflector() @@ -407,13 +407,13 @@ class RewriterImpl : public AstRewriterBase { // support message traversal. const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); - absl::optional rt_type = + std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) : absl::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); - absl::optional field_or = + std::optional field_or = GetSelectInstruction(runtime_type, planner_context_, field_name); if (field_or.has_value()) { candidates_[&expr] = std::move(field_or).value(); @@ -538,7 +538,7 @@ class RewriterImpl : public AstRewriterBase { return candidates_.find(operand) != candidates_.end(); } - absl::optional GetRuntimeType(absl::string_view type_name) { + std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( absl::nullopt); } @@ -582,14 +582,14 @@ class OptimizedSelectImpl { AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; - absl::optional attribute() const { return attribute_; } + std::optional attribute() const { return attribute_; } const std::vector& qualifiers() const { return qualifiers_; } private: - absl::optional attribute_; + std::optional attribute_; std::vector select_path_; std::vector qualifiers_; bool presence_test_; @@ -597,7 +597,7 @@ class OptimizedSelectImpl { }; // Check for unknowns or missing attributes. 
-absl::StatusOr> CheckForMarkedAttributes( +absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { return absl::nullopt; @@ -715,7 +715,7 @@ absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { // select arguments. // TODO(uncreated-issue/51): add support variable qualifiers attribute_trail = GetAttributeTrail(frame); - CEL_ASSIGN_OR_RETURN(absl::optional value, + CEL_ASSIGN_OR_RETURN(std::optional value, CheckForMarkedAttributes(*frame, attribute_trail)); if (value.has_value()) { frame->value_stack().Pop(kStackInputs); diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index c07f4c6ad..9d4024098 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -254,8 +254,9 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { return nullptr; } - absl::optional FindFieldByName( - absl::string_view field_name) const override { + std::optional< + google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> + FindFieldByName(absl::string_view field_name) const override { return absl::nullopt; } From b5db1b3dffb7b890f1a05110ad5833ebe0ebdbee Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 23:10:55 -0700 Subject: [PATCH 88/88] No public description PiperOrigin-RevId: 920403526 --- eval/compiler/flat_expr_builder.cc | 10 +++++----- eval/compiler/flat_expr_builder_extensions.cc | 2 +- eval/compiler/qualified_reference_resolver.cc | 10 +++++----- eval/compiler/regex_precompilation_optimization.cc | 6 +++--- eval/compiler/resolver.cc | 6 +++--- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index e38c912c0..8558c7007 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -840,7 +840,7 @@ class FlatExprVisitor : public 
cel::AstVisitor { // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. - absl::optional const_value; + std::optional const_value; int64_t select_root_id = -1; std::string path_candidate; @@ -1080,7 +1080,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Returns the maximum recursion depth of the current program if it is // eligible for recursion, or nullopt if it is not. - absl::optional RecursionEligible() { + std::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } @@ -1525,7 +1525,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - if (absl::optional depth = RecursionEligible(); depth.has_value()) { + if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1855,7 +1855,7 @@ class FlatExprVisitor : public cel::AstVisitor { int64_t expr_id) { absl::string_view ast_name = create_struct_expr.name(); - absl::optional> type; + std::optional> type; CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); if (!type.has_value()) { @@ -1932,7 +1932,7 @@ class FlatExprVisitor : public cel::AstVisitor { IndexManager index_manager_; bool enable_optional_types_; - absl::optional block_; + std::optional block_; int max_recursion_depth_ = 0; }; diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index 463b48425..e51b64023 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -98,7 +98,7 @@ size_t Subexpression::ComputeSize() const { return size; } -absl::optional Subexpression::RecursiveDependencyDepth() const { +std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { diff 
--git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 67f86ebb6..67c14d9b2 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -81,9 +81,9 @@ bool OverloadExists(const Resolver& resolver, absl::string_view name, // Return the qualified name of the most qualified matching overload, or // nullopt if no matches are found. -absl::optional BestOverloadMatch(const Resolver& resolver, - absl::string_view base_name, - int argument_count) { +std::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { if (IsSpecialFunction(base_name)) { return std::string(base_name); } @@ -262,8 +262,8 @@ class ReferenceResolver : public cel::AstRewriterBase { // Convert a select expr sub tree into a namespace name if possible. // If any operand of the top element is a not a select or an ident node, // return nullopt. - absl::optional ToNamespace(const Expr& expr) { - absl::optional maybe_parent_namespace; + std::optional ToNamespace(const Expr& expr) { + std::optional maybe_parent_namespace; if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index b94cae383..455796131 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -145,7 +145,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Try to check if the regex is valid, whether or not we can actually update // the plan. 
- absl::optional pattern = + std::optional pattern = GetConstantString(context, subexpression, node, pattern_expr); if (!pattern.has_value()) { return absl::OkStatus(); @@ -168,7 +168,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { } private: - absl::optional GetConstantString( + std::optional GetConstantString( PlannerContext& context, ProgramBuilder::Subexpression* absl_nullable subexpression, const Expr& call_expr, const Expr& re_expr) const { @@ -180,7 +180,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Already modified, can't recover the input pattern. return absl::nullopt; } - absl::optional constant; + std::optional constant; if (subexpression->IsRecursive()) { const auto& program = subexpression->recursive_program(); auto deps = program.step->GetDependencies(); diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 4e3fa3841..17f60eaad 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -102,8 +102,8 @@ absl::Span Resolver::GetPrefixesFor( return namespace_prefixes_; } -absl::optional Resolver::FindConstant(absl::string_view name, - int64_t expr_id) const { +std::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); @@ -205,7 +205,7 @@ std::vector Resolver::FindLazyOverloads( return funcs; } -absl::StatusOr>> +absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (auto& prefix : prefixes) {