diff --git a/jsonschema/json.go b/jsonschema/json.go index 03bb68891..29d15b409 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -48,6 +48,11 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` // Whether the schema is nullable or not. Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -67,10 +72,16 @@ func (d *Definition) Unmarshal(content string, v any) error { } func GenerateSchemaForType(v any) (*Definition, error) { - return reflectSchema(reflect.TypeOf(v)) + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + def.Defs = defs + return def, nil } -func reflectSchema(t reflect.Type) (*Definition, error) { +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d Definition switch t.Kind() { case reflect.String: @@ -84,21 +95,32 @@ func reflectSchema(t reflect.Type) (*Definition, error) { d.Type = Boolean case reflect.Slice, reflect.Array: d.Type = Array - items, err := reflectSchema(t.Elem()) + items, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } d.Items = items case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } d.Type = Object d.AdditionalProperties = false - object, err := reflectSchemaObject(t) + object, err := reflectSchemaObject(t, defs) if err != nil { return nil, err } d = *object case reflect.Ptr: - definition, err := reflectSchema(t.Elem()) + definition, err := 
reflectSchema(t.Elem(), defs) if err != nil { return nil, err } @@ -112,7 +134,7 @@ func reflectSchema(t reflect.Type) (*Definition, error) { return &d, nil } -func reflectSchemaObject(t reflect.Type) (*Definition, error) { +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d = Definition{ Type: Object, AdditionalProperties: false, @@ -136,7 +158,7 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { required = false } - item, err := reflectSchema(field.Type) + item, err := reflectSchema(field.Type, defs) if err != nil { return nil, err } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 84f25fa85..31b54ed1a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -183,6 +183,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { } func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + tests := []struct { name string in any @@ -376,6 +387,65 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + 
}, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } +}`, + }, } for _, tt := range tests { diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 49f9b8859..1bd2f809c 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -5,26 +5,68 @@ import ( "errors" ) +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix+"/items") + } +} + func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { var data any err := json.Unmarshal(content, &data) if err != nil { return err } - if !Validate(schema, data) { + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { return errors.New("data validation failed against the provided schema") } return json.Unmarshal(content, &v) } -func Validate(schema Definition, data any) bool { +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } switch schema.Type { case Object: - return validateObject(schema, data) + return validateObject(schema, data, args.Defs) case Array: - return validateArray(schema, data) + return
validateArray(schema, data, args.Defs) case String: - _, ok := data.(string) + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + } return ok case Number: // float64 and int _, ok := data.(float64) @@ -45,11 +87,16 @@ func Validate(schema Definition, data any) bool { case Null: return data == nil default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } return false } } -func validateObject(schema Definition, data any) bool { +func validateObject(schema Definition, data any, defs map[string]Definition) bool { dataMap, ok := data.(map[string]any) if !ok { return false @@ -61,7 +108,7 @@ func validateObject(schema Definition, data any) bool { } for key, valueSchema := range schema.Properties { value, exists := dataMap[key] - if exists && !Validate(valueSchema, value) { + if exists && !Validate(valueSchema, value, WithDefs(defs)) { return false } else if !exists && contains(schema.Required, key) { return false @@ -70,13 +117,13 @@ func validateObject(schema Definition, data any) bool { return true } -func validateArray(schema Definition, data any) bool { +func validateArray(schema Definition, data any, defs map[string]Definition) bool { dataArray, ok := data.([]any) if !ok { return false } for _, item := range dataArray { - if !Validate(*schema.Items, item) { + if schema.Items != nil && !Validate(*schema.Items, item, WithDefs(defs)) { return false } } diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 6fa30ab0c..aefdf4069 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,6 +1,7 @@ package jsonschema_test import ( + "reflect" "testing" "github.com/sashabaranov/go-openai/jsonschema" @@ -70,6 +71,96 @@ func Test_Validate(t *testing.T) { }, Required: []string{"string"}, }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender":
"male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: 
[]string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -156,8 +247,100 @@ func TestUnmarshal(t *testing.T) { err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) - } else if err == nil { - t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: 
map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + "#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) } }) }