From dcf9221c4b3443d1822a19a92f98834beb038995 Mon Sep 17 00:00:00 2001 From: ElvinChan Date: Tue, 18 Sep 2018 11:59:50 +0800 Subject: [PATCH] #9 Fix endless loop for recursive struct type --- internal_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++++ spec.go | 10 ++--- validator.go | 17 ++++++--- 3 files changed, 113 insertions(+), 10 deletions(-) diff --git a/internal_test.go b/internal_test.go index e14dce2..38dec01 100644 --- a/internal_test.go +++ b/internal_test.go @@ -12,6 +12,12 @@ func TestParamTypes(t *testing.T) { var pb *int64 var pc map[string]string var pd [][]float64 + type Parent struct { + Child struct { + Name string + } + } + var pe Parent tests := []struct { p interface{} panic bool @@ -52,6 +58,11 @@ func TestParamTypes(t *testing.T) { panic: false, name: "Array float64 type", }, + { + p: &pe, + panic: true, + name: "Struct type", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -71,6 +82,53 @@ func TestParamTypes(t *testing.T) { } } +func TestNestedParamTypes(t *testing.T) { + var pa struct { + Name string + } + var pb struct { + User struct { + Name string + } + } + tests := []struct { + p interface{} + panic bool + name string + }{ + { + p: 0, + panic: true, + name: "Basic type", + }, + { + p: pa, + panic: false, + name: "Struct type", + }, + { + p: pb, + panic: true, + name: "Nested struct type", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := prepareApi() + if tt.panic { + assert.Panics(t, func() { + a.AddParamPathNested(tt.p) + }) + } else { + a.AddParamPathNested(tt.p) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Equal(t, len(sapi.operation.Parameters), 1) + } + }) + } +} + func TestSchemaTypes(t *testing.T) { var pa interface{} var pb map[string]string @@ -159,3 +217,41 @@ func TestSchemaTypes(t *testing.T) { }) } } + +type testUser struct { + Id int64 + Name string + Pets []testPet +} + +type testPet struct { + Id int64 + Masters []testUser +} + +func TestRecursiveTypes(t *testing.T) { + tests := []struct { + p interface{} + name string + }{ + { + p: &testUser{}, + name: "User", + }, + { + p: &testPet{}, + name: "Pet", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := prepareApi() + a.AddParamBody(tt.p, tt.name, "", true) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Equal(t, len(sapi.operation.Parameters), 1) + assert.Equal(t, len(*sapi.defs), 2) + assert.Equal(t, tt.name, sapi.operation.Parameters[0].Name) + }) + } +} diff --git a/spec.go b/spec.go index 680bc8b..cd60e94 100644 --- a/spec.go +++ b/spec.go @@ -116,6 +116,11 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string { Properties: make(map[string]*JSONSchema), } + (*r)[key] = RawDefine{ + Value: v, + Schema: schema, + } + for i := 0; i < v.NumField(); i++ { f := v.Type().Field(i) name := getFieldName(f, ParamInBody) @@ -136,11 +141,6 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string { schema.handleSwaggerTags(f, name) } - (*r)[key] = RawDefine{ - Value: v, - Schema: schema, - } - if schema.XML == nil { schema.XML = &XMLSchema{} } diff --git a/validator.go b/validator.go index 49fe5d1..2689003 100644 --- a/validator.go +++ b/validator.go @@ -65,27 +65,34 @@ func isValidParam(t reflect.Type, nest, inner bool) bool { // invalid case: // 1. interface{} // 2. Map[Struct]string -func isValidSchema(t reflect.Type, inner bool) bool { +func isValidSchema(t reflect.Type, inner bool, pres ...reflect.Type) bool { if t == nil { return false } + for _, pre := range pres { + if t == pre { + return true + } + } + switch t.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.String: return true case reflect.Array, reflect.Slice: - return isValidSchema(t.Elem(), inner) + return isValidSchema(t.Elem(), inner, pres...) case reflect.Map: - return isBasicType(t.Key()) && isValidSchema(t.Elem(), true) + return isBasicType(t.Key()) && isValidSchema(t.Elem(), true, pres...) case reflect.Ptr: - return isValidSchema(t.Elem(), inner) + return isValidSchema(t.Elem(), inner, pres...) case reflect.Struct: + pres = append(pres, t) if t == reflect.TypeOf(time.Time{}) { return true } for i := 0; i < t.NumField(); i++ { - if !isValidSchema(t.Field(i).Type, true) { + if !isValidSchema(t.Field(i).Type, true, pres...) { return false } }