Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions common/env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *Config) AddVariableDecls(vars ...*decls.VariableDecl) *Config {
if v == nil {
continue
}
convVars[i] = NewVariable(v.Name(), serializeTypeDesc(v.Type()))
convVars[i] = NewVariable(v.Name(), SerializeTypeDesc(v.Type()))
}
return c.AddVariables(convVars...)
}
Expand Down Expand Up @@ -146,9 +146,9 @@ func (c *Config) AddFunctionDecls(funcs ...*decls.FunctionDecl) *Config {
overloadID := o.ID()
args := make([]*TypeDesc, 0, len(o.ArgTypes()))
for _, a := range o.ArgTypes() {
args = append(args, serializeTypeDesc(a))
args = append(args, SerializeTypeDesc(a))
}
ret := serializeTypeDesc(o.ResultType())
ret := SerializeTypeDesc(o.ResultType())
if o.IsMemberFunction() {
overloads = append(overloads, NewMemberOverload(overloadID, args[0], args[1:], ret))
} else {
Expand Down Expand Up @@ -836,7 +836,8 @@ func (td *TypeDesc) AsCELType(tp types.Provider) (*types.Type, error) {
}
}

func serializeTypeDesc(t *types.Type) *TypeDesc {
// SerializeTypeDesc converts *types.Type to a serialized format TypeDesc
func SerializeTypeDesc(t *types.Type) *TypeDesc {
typeName := t.TypeName()
if t.Kind() == types.TypeParamKind {
return NewTypeParam(typeName)
Expand All @@ -848,7 +849,7 @@ func serializeTypeDesc(t *types.Type) *TypeDesc {
}
var params []*TypeDesc
for _, p := range t.Parameters() {
params = append(params, serializeTypeDesc(p))
params = append(params, SerializeTypeDesc(p))
}
return NewTypeDesc(typeName, params...)
}
Expand Down
6 changes: 5 additions & 1 deletion ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ go_library(
"bindings.go",
"comprehensions.go",
"encoders.go",
"extension_option_factory.go",
"formatting.go",
"guards.go",
"lists.go",
Expand All @@ -26,6 +27,7 @@ go_library(
"//checker:go_default_library",
"//common/ast:go_default_library",
"//common/decls:go_default_library",
"//common/env:go_default_library",
"//common/overloads:go_default_library",
"//common/operators:go_default_library",
"//common/types:go_default_library",
Expand All @@ -48,7 +50,8 @@ go_test(
srcs = [
"bindings_test.go",
"comprehensions_test.go",
"encoders_test.go",
"encoders_test.go",
"extension_option_factory_test.go",
"lists_test.go",
"math_test.go",
"native_test.go",
Expand All @@ -62,6 +65,7 @@ go_test(
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common/env:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
Expand Down
72 changes: 72 additions & 0 deletions ext/extension_option_factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ext

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/env"
)

// ExtensionOptionFactory converts an ExtensionConfig value to a CEL environment option.
func ExtensionOptionFactory(configElement any) (cel.EnvOption, bool) {
ext, isExtension := configElement.(*env.Extension)
if !isExtension {
return nil, false
}
fac, found := extFactories[ext.Name]
if !found {
return nil, false
}
// If the version is 'latest', set the version value to the max uint.
ver, err := ext.VersionNumber()
if err != nil {
return func(*cel.Env) (*cel.Env, error) {
return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version)
}, true
}
return fac(ver), true
}

// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension.
type extensionFactory func(uint32) cel.EnvOption

var extFactories = map[string]extensionFactory{
"bindings": func(version uint32) cel.EnvOption {
return Bindings(BindingsVersion(version))
},
"encoders": func(version uint32) cel.EnvOption {
return Encoders(EncodersVersion(version))
},
"lists": func(version uint32) cel.EnvOption {
return Lists(ListsVersion(version))
},
"math": func(version uint32) cel.EnvOption {
return Math(MathVersion(version))
},
"protos": func(version uint32) cel.EnvOption {
return Protos(ProtosVersion(version))
},
"sets": func(version uint32) cel.EnvOption {
return Sets(SetsVersion(version))
},
"strings": func(version uint32) cel.EnvOption {
return Strings(StringsVersion(version))
},
"two-var-comprehensions": func(version uint32) cel.EnvOption {
return TwoVarComprehensions(TwoVarComprehensionsVersion(version))
},
}
67 changes: 67 additions & 0 deletions ext/extension_option_factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ext

import (
"fmt"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/env"
)

func TestExtensionOptionFactoryInvalidExtension(t *testing.T) {
invalidExtension := "invalid extension"
_, validExtension := ExtensionOptionFactory(invalidExtension)
if validExtension {
t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid input", invalidExtension)
}
}

func TestExtensionOptionFactoryInvalidExtensionName(t *testing.T) {
e := &env.Extension{Name: "invalid extension name"}
_, validExtension := ExtensionOptionFactory(e)
if validExtension {
t.Fatalf("ExtensionOptionFactory(%s) returned valid extension for invalid extension name", e.Name)
}
}

func TestExtensionOptionFactoryInvalidExtensionVersion(t *testing.T) {
e := &env.Extension{Name: "bindings", Version: "invalid version"}
opt, validExtension := ExtensionOptionFactory(e)
if !validExtension {
t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name)
}
_, err := cel.NewCustomEnv(opt)
if err == nil || err.Error() != fmt.Sprintf("invalid extension version: %s - %s", e.Name, e.Version) {
t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension version", e.Name)
}
}

func TestExtensionOptionFactoryValidBindingsExtension(t *testing.T) {
e := &env.Extension{Name: "bindings", Version: "latest"}
opt, validExtension := ExtensionOptionFactory(e)
if !validExtension {
t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name)
}
en, err := cel.NewCustomEnv(opt)
if err != nil {
t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name)
}
cfg, err := en.ToConfig("test config")
if len(cfg.Extensions) != 1 || cfg.Extensions[0].Name != "cel.lib.ext.cel.bindings" || cfg.Extensions[0].Version != "latest" {
t.Fatalf("ExtensionOptionFactory(%s) returned invalid extension", e.Name)
}
}
54 changes: 1 addition & 53 deletions policy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package policy

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/ext"
Expand All @@ -28,55 +26,5 @@ import (
// a set of configuration ConfigOptionFactory values to handle extensions and other config features
// which may be defined outside of the `cel` package.
func FromConfig(config *env.Config) cel.EnvOption {
return cel.FromConfig(config, extensionOptionFactory)
}

// extensionOptionFactory converts an ExtensionConfig value to a CEL environment option.
func extensionOptionFactory(configElement any) (cel.EnvOption, bool) {
ext, isExtension := configElement.(*env.Extension)
if !isExtension {
return nil, false
}
fac, found := extFactories[ext.Name]
if !found {
return nil, false
}
// If the version is 'latest', set the version value to the max uint.
ver, err := ext.VersionNumber()
if err != nil {
return func(*cel.Env) (*cel.Env, error) {
return nil, fmt.Errorf("invalid extension version: %s - %s", ext.Name, ext.Version)
}, true
}
return fac(ver), true
}

// extensionFactory accepts a version and produces a CEL environment associated with the versioned extension.
type extensionFactory func(uint32) cel.EnvOption

var extFactories = map[string]extensionFactory{
"bindings": func(version uint32) cel.EnvOption {
return ext.Bindings(ext.BindingsVersion(version))
},
"encoders": func(version uint32) cel.EnvOption {
return ext.Encoders(ext.EncodersVersion(version))
},
"lists": func(version uint32) cel.EnvOption {
return ext.Lists(ext.ListsVersion(version))
},
"math": func(version uint32) cel.EnvOption {
return ext.Math(ext.MathVersion(version))
},
"protos": func(version uint32) cel.EnvOption {
return ext.Protos(ext.ProtosVersion(version))
},
"sets": func(version uint32) cel.EnvOption {
return ext.Sets(ext.SetsVersion(version))
},
"strings": func(version uint32) cel.EnvOption {
return ext.Strings(ext.StringsVersion(version))
},
"two-var-comprehensions": func(version uint32) cel.EnvOption {
return ext.TwoVarComprehensions(ext.TwoVarComprehensionsVersion(version))
},
return cel.FromConfig(config, ext.ExtensionOptionFactory)
}