Skip to content

Commit 8956ee2

Browse files
committed
fix: auto-deref for map/slice and conditionals
Ensure pointers are automatically derefer'd in map keys, slice indices, and conditional expressions. Add dereferencing for map keys to prefer exact type matches (e.g. map[*T]) before falling back to deref'd types. Regression test added. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent 593f93f commit 8956ee2

File tree

3 files changed

+128
-0
lines changed

3 files changed

+128
-0
lines changed

checker/checker.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,13 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature {
549549

550550
switch base.Kind {
551551
case reflect.Map:
552+
// If the map key is a pointer, we should not dereference the property.
553+
if !prop.AssignableTo(base.Key(&v.config.NtCache)) {
554+
propDeref := prop.Deref(&v.config.NtCache)
555+
if propDeref.AssignableTo(base.Key(&v.config.NtCache)) {
556+
prop = propDeref
557+
}
558+
}
552559
if !prop.AssignableTo(base.Key(&v.config.NtCache)) && !prop.IsUnknown(&v.config.NtCache) {
553560
return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String())
554561
}
@@ -562,6 +569,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature {
562569
return base.Elem(&v.config.NtCache)
563570

564571
case reflect.Array, reflect.Slice:
572+
prop = prop.Deref(&v.config.NtCache)
565573
if !prop.IsInteger && !prop.IsUnknown(&v.config.NtCache) {
566574
return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String())
567575
}
@@ -607,13 +615,15 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature {
607615

608616
if node.From != nil {
609617
from := v.visit(node.From)
618+
from = from.Deref(&v.config.NtCache)
610619
if !from.IsInteger && !from.IsUnknown(&v.config.NtCache) {
611620
return v.error(node.From, "non-integer slice index %v", from.String())
612621
}
613622
}
614623

615624
if node.To != nil {
616625
to := v.visit(node.To)
626+
to = to.Deref(&v.config.NtCache)
617627
if !to.IsInteger && !to.IsUnknown(&v.config.NtCache) {
618628
return v.error(node.To, "non-integer slice index %v", to.String())
619629
}
@@ -942,6 +952,7 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature {
942952

943953
base := v.visit(node.Arguments[0])
944954
prop := v.visit(node.Arguments[1])
955+
prop = prop.Deref(&v.config.NtCache)
945956

946957
if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" {
947958
if s, ok := node.Arguments[1].(*ast.StringNode); ok {
@@ -1260,6 +1271,7 @@ func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature {
12601271

12611272
func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature {
12621273
c := v.visit(node.Cond)
1274+
c = c.Deref(&v.config.NtCache)
12631275
if !c.IsBool() && !c.IsUnknown(&v.config.NtCache) {
12641276
return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String())
12651277
}

compiler/compiler.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,18 @@ func (c *compiler) MemberNode(node *ast.MemberNode) {
716716

717717
if op == OpFetch {
718718
c.compile(node.Property)
719+
deref := true
720+
// If the map key is a pointer, we should not dereference the property.
721+
if node.Node.Type() != nil && node.Node.Type().Kind() == reflect.Map {
722+
keyType := node.Node.Type().Key()
723+
propType := node.Property.Type()
724+
if propType != nil && propType.AssignableTo(keyType) {
725+
deref = false
726+
}
727+
}
728+
if deref {
729+
c.derefInNeeded(node.Property)
730+
}
719731
c.emit(OpFetch)
720732
} else {
721733
c.emitLocation(node.Location(), op, c.addConstant(
@@ -728,11 +740,13 @@ func (c *compiler) SliceNode(node *ast.SliceNode) {
728740
c.compile(node.Node)
729741
if node.To != nil {
730742
c.compile(node.To)
743+
c.derefInNeeded(node.To)
731744
} else {
732745
c.emit(OpLen)
733746
}
734747
if node.From != nil {
735748
c.compile(node.From)
749+
c.derefInNeeded(node.From)
736750
} else {
737751
c.emitPush(0)
738752
}
@@ -1213,6 +1227,7 @@ func (c *compiler) lookupVariable(name string) (int, bool) {
12131227

12141228
func (c *compiler) ConditionalNode(node *ast.ConditionalNode) {
12151229
c.compile(node.Cond)
1230+
c.derefInNeeded(node.Cond)
12161231
otherwise := c.emit(OpJumpIfFalse, placeholder)
12171232

12181233
c.emit(OpPop)

test/issues/836/issue_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package issue_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
"github.com/expr-lang/expr/internal/testify/require"
8+
)
9+
10+
type InputStruct struct {
11+
Enabled *bool `json:"enabled"`
12+
}
13+
14+
func TestIssue836(t *testing.T) {
15+
str := "foo"
16+
ptrStr := &str
17+
b := true
18+
ptrBool := &b
19+
i := 1
20+
ptrInt := &i
21+
22+
env := map[string]interface{}{
23+
"ptrStr": ptrStr,
24+
"ptrBool": ptrBool,
25+
"ptrInt": ptrInt,
26+
"arr": []int{1, 2, 3},
27+
"mapPtr": map[*int]int{ptrInt: 42},
28+
}
29+
30+
t.Run("map access with pointer key", func(t *testing.T) {
31+
program, err := expr.Compile(`{"foo": "bar"}[ptrStr]`, expr.Env(env))
32+
require.NoError(t, err)
33+
34+
output, err := expr.Run(program, env)
35+
require.NoError(t, err)
36+
require.Equal(t, "bar", output)
37+
})
38+
39+
t.Run("conditional with pointer condition", func(t *testing.T) {
40+
program, err := expr.Compile(`ptrBool ? 1 : 0`, expr.Env(env))
41+
require.NoError(t, err)
42+
43+
output, err := expr.Run(program, env)
44+
require.NoError(t, err)
45+
require.Equal(t, 1, output)
46+
})
47+
48+
t.Run("get() with pointer key", func(t *testing.T) {
49+
program, err := expr.Compile(`get({"foo": "bar"}, ptrStr)`, expr.Env(env))
50+
require.NoError(t, err)
51+
52+
output, err := expr.Run(program, env)
53+
require.NoError(t, err)
54+
require.Equal(t, "bar", output)
55+
})
56+
57+
t.Run("struct field pointer check in ternary", func(t *testing.T) {
58+
var v InputStruct
59+
// v.Enabled is nil
60+
61+
env := map[string]any{
62+
"v": v,
63+
}
64+
65+
code := `v.Enabled == nil ? 'default' : ( v.Enabled ? 'enabled' : 'disabled' )`
66+
67+
program, err := expr.Compile(code, expr.Env(env))
68+
require.NoError(t, err)
69+
70+
output, err := expr.Run(program, env)
71+
require.NoError(t, err)
72+
require.Equal(t, "default", output)
73+
})
74+
75+
t.Run("struct field pointer check in ternary (enabled)", func(t *testing.T) {
76+
b := true
77+
v := InputStruct{Enabled: &b}
78+
79+
env := map[string]any{
80+
"v": v,
81+
}
82+
83+
code := `v.Enabled == nil ? 'default' : ( v.Enabled ? 'enabled' : 'disabled' )`
84+
85+
program, err := expr.Compile(code, expr.Env(env))
86+
require.NoError(t, err)
87+
88+
output, err := expr.Run(program, env)
89+
require.NoError(t, err)
90+
require.Equal(t, "enabled", output)
91+
})
92+
93+
t.Run("slice with pointer indices", func(t *testing.T) {
94+
program, err := expr.Compile(`arr[ptrInt:ptrInt]`, expr.Env(env))
95+
require.NoError(t, err)
96+
97+
output, err := expr.Run(program, env)
98+
require.NoError(t, err)
99+
require.Equal(t, []int{}, output)
100+
})
101+
}

0 commit comments

Comments
 (0)