Skip to content

Commit 8effd0f

Browse files
authored
Refactor proxy.shouldTerminate function and move the functionality to Act.Registry (#615)
* Refactor proxy.shouldTerminate function and move the functionality to Act.Registry * Simplify syntax and avoid unnecessary loops (using maps.Keys) * Create RunAll and ShouldTerminate functions in Act.Registry * Add logger to RunAll * Add test for RunAll and ShouldTerminate * Separate the check for existence of 'outputs' key from type check * Add test case for failures * Address comments by @sinadarbouy
1 parent e8596c5 commit 8effd0f

File tree

3 files changed

+153
-35
lines changed

3 files changed

+153
-35
lines changed

act/registry.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ import (
1212
"github.com/gatewayd-io/gatewayd/config"
1313
gerr "github.com/gatewayd-io/gatewayd/errors"
1414
"github.com/rs/zerolog"
15+
"github.com/spf13/cast"
1516
)
1617

1718
type IRegistry interface {
1819
Add(policy *sdkAct.Policy)
1920
Apply(signals []sdkAct.Signal, hook sdkAct.Hook) []*sdkAct.Output
2021
Run(output *sdkAct.Output, params ...sdkAct.Parameter) (any, *gerr.GatewayDError)
22+
RunAll(result map[string]any) map[string]any
23+
ShouldTerminate(result map[string]any) bool
2124
}
2225

2326
// Registry keeps track of all policies and actions.
@@ -402,6 +405,70 @@ func runActionWithTimeout(
402405
}
403406
}
404407

408+
// RunAll run all the actions in the outputs and returns the end result.
409+
func (r *Registry) RunAll(result map[string]any) map[string]any {
410+
if _, exists := result[sdkAct.Outputs]; !exists {
411+
r.Logger.Debug().Msg("Outputs key is not present, returning the result as-is")
412+
return result
413+
}
414+
415+
var (
416+
outputs []*sdkAct.Output
417+
ok bool
418+
)
419+
if outputs, ok = result[sdkAct.Outputs].([]*sdkAct.Output); !ok || len(outputs) == 0 {
420+
r.Logger.Debug().Msg("Outputs are nil or empty, returning the result as-is")
421+
// If the outputs are nil or empty, we should delete the key from the result.
422+
delete(result, sdkAct.Outputs)
423+
return result
424+
}
425+
426+
endResult := make(map[string]any)
427+
for _, output := range outputs {
428+
if !cast.ToBool(output.Verdict) {
429+
r.Logger.Debug().Msg(
430+
"Skipping the action, because the verdict of the policy execution is false")
431+
continue
432+
}
433+
runResult, err := r.Run(output, WithResult(result), WithLogger(r.Logger))
434+
// If the action is async and we received a sentinel error, don't log the error.
435+
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
436+
r.Logger.Error().Err(err).Msg("Error running policy")
437+
}
438+
// Each action should return a map.
439+
if v, ok := runResult.(map[string]any); ok {
440+
endResult = v
441+
} else {
442+
r.Logger.Debug().Msg("Run result is not a map, skipping merging into end result.")
443+
}
444+
}
445+
return endResult
446+
}
447+
448+
// ShouldTerminate checks if any of the actions are terminal, indicating that the request
449+
// should be terminated.
450+
// This is an optimization to avoid executing the actions' functions unnecessarily.
451+
// The __terminal__ field is only set when an action intends to terminate the request.
452+
func (r *Registry) ShouldTerminate(result map[string]any) bool {
453+
terminalVal, exists := result[sdkAct.Terminal]
454+
if !exists {
455+
r.Logger.Debug().Msg("Terminal key not found, request will continue.")
456+
return false
457+
}
458+
459+
shouldTerminate, ok := terminalVal.(bool)
460+
if !ok {
461+
r.Logger.Debug().Msg("Terminal key exists but cannot be cast to a boolean.")
462+
return false
463+
}
464+
465+
if shouldTerminate {
466+
r.Logger.Debug().Msg("Request is marked as terminal. Terminating.")
467+
}
468+
469+
return shouldTerminate
470+
}
471+
405472
// WithLogger returns a parameter with the Logger to be used by the action.
406473
// This is automatically prepended to the parameters when running an action.
407474
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {

act/registry_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,3 +930,87 @@ func Test_Run_Timeout(t *testing.T) {
930930
})
931931
}
932932
}
933+
934+
// Test_RunAll_And_ShouldTerminate tests the RunAll function of the act registry
935+
// with a terminal action (and signal).
936+
func Test_RunAll_And_ShouldTerminate(t *testing.T) {
937+
out := bytes.Buffer{}
938+
logger := zerolog.New(&out)
939+
actRegistry := NewActRegistry(
940+
Registry{
941+
Signals: BuiltinSignals(),
942+
Policies: BuiltinPolicies(),
943+
Actions: BuiltinActions(),
944+
DefaultPolicyName: config.DefaultPolicy,
945+
PolicyTimeout: config.DefaultPolicyTimeout,
946+
DefaultActionTimeout: config.DefaultActionTimeout,
947+
Logger: logger,
948+
})
949+
assert.NotNil(t, actRegistry)
950+
951+
outputs := actRegistry.Apply([]sdkAct.Signal{
952+
*sdkAct.Terminate(),
953+
*sdkAct.Log("info", "testing log via Act", map[string]any{"test": true}),
954+
}, sdkAct.Hook{
955+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
956+
Priority: 1000,
957+
Params: map[string]any{},
958+
Result: map[string]any{},
959+
})
960+
assert.NotNil(t, outputs)
961+
962+
// This is what the hook returns along with "request", "response" and other fields.
963+
// These two keys and values should exist in the result after policy execution.
964+
result := map[string]any{
965+
sdkAct.Outputs: outputs,
966+
sdkAct.Terminal: true,
967+
}
968+
969+
assert.True(t, actRegistry.ShouldTerminate(result))
970+
971+
result = actRegistry.RunAll(result)
972+
973+
time.Sleep(time.Millisecond) // wait for async action to complete
974+
975+
assert.NotEmpty(t, result)
976+
// Terminate action does nothing when run. It is just a signal to terminate.
977+
assert.Contains(t, out.String(),
978+
`{"level":"debug","action":"terminate","executionMode":"sync","message":"Running action"}`)
979+
assert.Contains(t, out.String(),
980+
`{"level":"debug","action":"log","executionMode":"async","message":"Running action"}`)
981+
assert.Contains(t, out.String(), `{"level":"info","test":true,"message":"testing log via Act"}`)
982+
}
983+
984+
// Test_RunAll_Empty_Result tests the RunAll function of the act registry with an empty result.
985+
func Test_RunAll_Empty_Result(t *testing.T) {
986+
out := bytes.Buffer{}
987+
logger := zerolog.New(&out)
988+
actRegistry := NewActRegistry(
989+
Registry{
990+
Signals: BuiltinSignals(),
991+
Policies: BuiltinPolicies(),
992+
Actions: BuiltinActions(),
993+
DefaultPolicyName: config.DefaultPolicy,
994+
PolicyTimeout: config.DefaultPolicyTimeout,
995+
DefaultActionTimeout: config.DefaultActionTimeout,
996+
Logger: logger,
997+
})
998+
assert.NotNil(t, actRegistry)
999+
1000+
results := []map[string]any{
1001+
{},
1002+
{
1003+
sdkAct.Outputs: false, // This is invalid, hence it will be removed.
1004+
},
1005+
}
1006+
1007+
for _, result := range results {
1008+
assert.False(t, actRegistry.ShouldTerminate(result))
1009+
1010+
result = actRegistry.RunAll(result)
1011+
1012+
time.Sleep(time.Millisecond) // wait for async action to complete
1013+
1014+
assert.Empty(t, result)
1015+
}
1016+
}

network/proxy.go

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@ import (
66
"errors"
77
"io"
88
"net"
9-
"slices"
109
"time"
1110

12-
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
1311
"github.com/gatewayd-io/gatewayd-plugin-sdk/databases/postgres"
1412
v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1"
15-
"github.com/gatewayd-io/gatewayd/act"
1613
"github.com/gatewayd-io/gatewayd/config"
1714
gerr "github.com/gatewayd-io/gatewayd/errors"
1815
"github.com/gatewayd-io/gatewayd/metrics"
@@ -21,9 +18,7 @@ import (
2118
"github.com/getsentry/sentry-go"
2219
"github.com/go-co-op/gocron"
2320
"github.com/rs/zerolog"
24-
"github.com/spf13/cast"
2521
"go.opentelemetry.io/otel"
26-
"golang.org/x/exp/maps"
2722
)
2823

2924
//nolint:interfacebloat
@@ -873,36 +868,8 @@ func (pr *Proxy) shouldTerminate(result map[string]any) (bool, map[string]any) {
873868
return false, result
874869
}
875870

876-
outputs, ok := result[sdkAct.Outputs].([]*sdkAct.Output)
877-
if !ok {
878-
pr.Logger.Error().Msg("Failed to cast the outputs to the []*act.Output type")
879-
return false, result
880-
}
881-
882-
// This is a shortcut to avoid running the actions' functions.
883-
// The Terminal field is only present if the action wants to terminate the request,
884-
// that is the `__terminal__` field is set in one of the outputs.
885-
keys := maps.Keys(result)
886-
terminate := slices.Contains(keys, sdkAct.Terminal) && cast.ToBool(result[sdkAct.Terminal])
887-
actionResult := make(map[string]any)
888-
for _, output := range outputs {
889-
if !cast.ToBool(output.Verdict) {
890-
pr.Logger.Debug().Msg(
891-
"Skipping the action, because the verdict of the policy execution is false")
892-
continue
893-
}
894-
actRes, err := pr.PluginRegistry.ActRegistry.Run(
895-
output, act.WithResult(result))
896-
// If the action is async and we received a sentinel error,
897-
// don't log the error.
898-
if err != nil && !errors.Is(err, gerr.ErrAsyncAction) {
899-
pr.Logger.Error().Err(err).Msg("Error running policy")
900-
}
901-
// The terminate action should return a map.
902-
if v, ok := actRes.(map[string]any); ok {
903-
actionResult = v
904-
}
905-
}
871+
terminate := pr.PluginRegistry.ActRegistry.ShouldTerminate(result)
872+
actionResult := pr.PluginRegistry.ActRegistry.RunAll(result)
906873
if terminate {
907874
pr.Logger.Debug().Fields(
908875
map[string]any{

0 commit comments

Comments
 (0)