Skip to content
Open
158 changes: 158 additions & 0 deletions deep-filtering.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"sync"

"github.com/survivorbat/go-tsyncmap"
gormqonvert "github.com/survivorbat/gorm-query-convert"
"gorm.io/gorm/schema"

"gorm.io/gorm"
Expand Down Expand Up @@ -81,11 +85,27 @@ func AddDeepFilters(db *gorm.DB, objectType any, filters ...map[string]any) (*go
relationalTypesInfo := getDatabaseFieldsOfType(db.NamingStrategy, schemaInfo)

simpleFilter := map[string]any{}
totalFilterString := ""
functionRegex := regexp.MustCompile(`.*\((.*?)\)+`)
qonvertMap := map[string]string{}

if _, ok := db.Plugins[gormqonvert.New(gormqonvert.CharacterConfig{}).Name()]; ok {
qonvertPlugin := db.Plugins[gormqonvert.New(gormqonvert.CharacterConfig{}).Name()]
qonvertPluginConfig := reflect.ValueOf(qonvertPlugin).Elem().FieldByName("config")
qonvertMap[qonvertPluginConfig.FieldByName("GreaterThanPrefix").String()] = ">"
qonvertMap[qonvertPluginConfig.FieldByName("GreaterOrEqualToPrefix").String()] = ">="
qonvertMap[qonvertPluginConfig.FieldByName("LessThanPrefix").String()] = "<"
qonvertMap[qonvertPluginConfig.FieldByName("LessOrEqualToPrefix").String()] = "<="
qonvertMap[qonvertPluginConfig.FieldByName("NotEqualToPrefix").String()] = "!="
qonvertMap[qonvertPluginConfig.FieldByName("LikePrefix").String()] = "%s LIKE '%%%s%%'"
qonvertMap[qonvertPluginConfig.FieldByName("NotLikePrefix").String()] = "%s NOT LIKE '%%%s%%'"
}

// Go through the filters
for _, filterObject := range filters {
// Go through all the keys of the filters
for fieldName, givenFilter := range filterObject {
filterString := ""
switch givenFilter.(type) {
// WithFilters for relational objects
case map[string]any:
Expand All @@ -105,6 +125,116 @@ func AddDeepFilters(db *gorm.DB, objectType any, filters ...map[string]any) (*go

// Simple filters (string, int, bool etc.)
default:
// If the simple filter contains a function, build the query different
if functionRegex.MatchString(fieldName) {
checkFieldName := functionRegex.ReplaceAllString(fieldName, "$1")
if _, ok := schemaInfo.FieldsByDBName[checkFieldName]; !ok {
return nil, fmt.Errorf("failed to add filters for '%s.%s': %w", schemaInfo.Table, checkFieldName, ErrFieldDoesNotExist)
}

if _, ok := givenFilter.([]string); ok {
for _, filter := range givenFilter.([]string) {
containedQonvert := false
qonvertOperator := ""
qonvertFilter := ""
for qonvertField, qonvertValue := range qonvertMap {
// Find longest possible prefix match in filter
// (e.g. to make sure <= is fully matched and not overwritten by <)
if strings.HasPrefix(filter, qonvertField) && len(qonvertField) > len(qonvertFilter) {
qonvertOperator = qonvertValue
qonvertFilter = qonvertField
containedQonvert = true
}
}

if !containedQonvert {
if filterString == "" {
filterString = fmt.Sprintf("%s = '%s'", fieldName, filter)
} else {
filterString = fmt.Sprintf("%s OR %s", filterString, fmt.Sprintf("%s = '%s'", fieldName, filter))
}
} else {
filter = strings.Replace(filter, qonvertFilter, "", 1)
if strings.Contains(qonvertOperator, "%") {
if filterString == "" {
filterString = fmt.Sprintf(qonvertOperator, fieldName, filter)
} else {
filterString = fmt.Sprintf("%s OR %s", filterString, fmt.Sprintf(qonvertOperator, fieldName, filter))
}
} else {
if filterString == "" {
filterString = prepareFilterValue(fieldName, qonvertOperator, filter)
} else {
filterString = fmt.Sprintf("%s OR %s", filterString, prepareFilterValue(fieldName, qonvertOperator, filter))
}
}
}
}

filterString = fmt.Sprintf("(%s)", filterString)
if totalFilterString != "" {
totalFilterString += " AND "
}
totalFilterString += filterString
break
}

if _, ok := givenFilter.([]int); ok {
for _, filter := range givenFilter.([]int) {
if filterString == "" {
filterString = fmt.Sprintf("%s = %d", fieldName, filter)
} else {
filterString = fmt.Sprintf("%s OR %s", filterString, fmt.Sprintf("%s = %d", fieldName, filter))
}
}

filterString = fmt.Sprintf("(%s)", filterString)
if totalFilterString != "" {
totalFilterString += " AND "
}
totalFilterString += filterString
break
}

containedQonvert := false
qonvertOperator := ""
qonvertFilter := ""
for qonvertField, qonvertValue := range qonvertMap {
// Find longest possible prefix match in filter
// (e.g. to make sure <= is fully matched and not overwritten by <)
if filterStrCast, castOk := givenFilter.(string); castOk && strings.HasPrefix(filterStrCast, qonvertField) && len(qonvertField) > len(qonvertFilter) {
qonvertOperator = qonvertValue
qonvertFilter = qonvertField
containedQonvert = true
}

}

if containedQonvert {
givenFilter = strings.Replace(givenFilter.(string), qonvertFilter, "", 1)
if strings.Contains(qonvertOperator, "%") {
if filterString == "" {
filterString = fmt.Sprintf(qonvertOperator, fieldName, givenFilter)
} else {
filterString = fmt.Sprintf("%s OR %s", filterString, fmt.Sprintf(qonvertOperator, fieldName, givenFilter))
}
} else {
filterString = prepareFilterValue(fieldName, qonvertOperator, givenFilter.(string))
}
}

if filterString == "" {
filterString = prepareFilterValueCast(fieldName, "=", givenFilter)
}

filterString = fmt.Sprintf("(%s)", filterString)
if totalFilterString != "" {
totalFilterString += " AND "
}
totalFilterString += filterString
break
}

if _, ok := schemaInfo.FieldsByDBName[fieldName]; !ok {
return nil, fmt.Errorf("failed to add filters for '%s.%s': %w", schemaInfo.Table, fieldName, ErrFieldDoesNotExist)
}
Expand All @@ -114,6 +244,9 @@ func AddDeepFilters(db *gorm.DB, objectType any, filters ...map[string]any) (*go
}

// Add simple filters
if totalFilterString != "" {
db = db.Where(totalFilterString)
}
db = db.Where(simpleFilter)

return db, nil
Expand Down Expand Up @@ -148,6 +281,31 @@ type iKind[T any] interface {
Elem() T
}

// prepareFilterValue checks if the given filter can be converted to an int or bool and gives back the correct SQL value for it
func prepareFilterValue(fieldName string, operator string, filterValue string) string {
if value, err := strconv.Atoi(filterValue); err == nil {
return fmt.Sprintf("%s %s %d", fieldName, operator, value)
}

if value, err := strconv.ParseBool(filterValue); err == nil {
return fmt.Sprintf("%s %s %t", fieldName, operator, value)
}

return fmt.Sprintf("%s %s '%s'", fieldName, operator, filterValue)
}

// prepareFilterValue checks if the given filter can be converted to an int or bool and gives back the correct SQL value for it
func prepareFilterValueCast(fieldName string, operator string, filterValue any) string {
if filterIntCast, castOk := filterValue.(int); castOk {
return fmt.Sprintf("%s %s %d", fieldName, operator, filterIntCast)
}
if filterBoolCast, castOk := filterValue.(bool); castOk {
return fmt.Sprintf("%s %s %t", fieldName, operator, filterBoolCast)
}

return fmt.Sprintf("%s %s '%s'", fieldName, operator, filterValue.(string))
}

// ensureConcrete ensures that the given value is a value and not a pointer, if it is, convert it to its element type
func ensureConcrete[T iKind[T]](value T) T {
if value.Kind() == reflect.Ptr {
Expand Down
Loading