#9 Fix endless loop for recursive struct type

This commit is contained in:
ElvinChan 2018-09-18 11:59:50 +08:00
parent 0e96d943fb
commit dcf9221c4b
3 changed files with 113 additions and 10 deletions

View File

@ -12,6 +12,12 @@ func TestParamTypes(t *testing.T) {
var pb *int64 var pb *int64
var pc map[string]string var pc map[string]string
var pd [][]float64 var pd [][]float64
type Parent struct {
Child struct {
Name string
}
}
var pe Parent
tests := []struct { tests := []struct {
p interface{} p interface{}
panic bool panic bool
@ -52,6 +58,11 @@ func TestParamTypes(t *testing.T) {
panic: false, panic: false,
name: "Array float64 type", name: "Array float64 type",
}, },
{
p: &pe,
panic: true,
name: "Struct type",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -71,6 +82,53 @@ func TestParamTypes(t *testing.T) {
} }
} }
func TestNestedParamTypes(t *testing.T) {
var pa struct {
Name string
}
var pb struct {
User struct {
Name string
}
}
tests := []struct {
p interface{}
panic bool
name string
}{
{
p: 0,
panic: true,
name: "Basic type",
},
{
p: pa,
panic: false,
name: "Struct type",
},
{
p: pb,
panic: true,
name: "Nested struct type",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := prepareApi()
if tt.panic {
assert.Panics(t, func() {
a.AddParamPathNested(tt.p)
})
} else {
a.AddParamPathNested(tt.p)
sapi, ok := a.(*api)
assert.Equal(t, ok, true)
assert.Equal(t, len(sapi.operation.Parameters), 1)
}
})
}
}
func TestSchemaTypes(t *testing.T) { func TestSchemaTypes(t *testing.T) {
var pa interface{} var pa interface{}
var pb map[string]string var pb map[string]string
@ -159,3 +217,41 @@ func TestSchemaTypes(t *testing.T) {
}) })
} }
} }
type testUser struct {
Id int64
Name string
Pets []testPet
}
type testPet struct {
Id int64
Masters []testUser
}
func TestRecursiveTypes(t *testing.T) {
tests := []struct {
p interface{}
name string
}{
{
p: &testUser{},
name: "User",
},
{
p: &testPet{},
name: "Pet",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := prepareApi()
a.AddParamBody(tt.p, tt.name, "", true)
sapi, ok := a.(*api)
assert.Equal(t, ok, true)
assert.Equal(t, len(sapi.operation.Parameters), 1)
assert.Equal(t, len(*sapi.defs), 2)
assert.Equal(t, tt.name, sapi.operation.Parameters[0].Name)
})
}
}

10
spec.go
View File

@ -116,6 +116,11 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string {
Properties: make(map[string]*JSONSchema), Properties: make(map[string]*JSONSchema),
} }
(*r)[key] = RawDefine{
Value: v,
Schema: schema,
}
for i := 0; i < v.NumField(); i++ { for i := 0; i < v.NumField(); i++ {
f := v.Type().Field(i) f := v.Type().Field(i)
name := getFieldName(f, ParamInBody) name := getFieldName(f, ParamInBody)
@ -136,11 +141,6 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string {
schema.handleSwaggerTags(f, name) schema.handleSwaggerTags(f, name)
} }
(*r)[key] = RawDefine{
Value: v,
Schema: schema,
}
if schema.XML == nil { if schema.XML == nil {
schema.XML = &XMLSchema{} schema.XML = &XMLSchema{}
} }

View File

@ -65,27 +65,34 @@ func isValidParam(t reflect.Type, nest, inner bool) bool {
// invalid case: // invalid case:
// 1. interface{} // 1. interface{}
// 2. Map[Struct]string // 2. Map[Struct]string
func isValidSchema(t reflect.Type, inner bool) bool { func isValidSchema(t reflect.Type, inner bool, pres ...reflect.Type) bool {
if t == nil { if t == nil {
return false return false
} }
for _, pre := range pres {
if t == pre {
return true
}
}
switch t.Kind() { switch t.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64, reflect.String: reflect.Float32, reflect.Float64, reflect.String:
return true return true
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
return isValidSchema(t.Elem(), inner) return isValidSchema(t.Elem(), inner, pres...)
case reflect.Map: case reflect.Map:
return isBasicType(t.Key()) && isValidSchema(t.Elem(), true) return isBasicType(t.Key()) && isValidSchema(t.Elem(), true, pres...)
case reflect.Ptr: case reflect.Ptr:
return isValidSchema(t.Elem(), inner) return isValidSchema(t.Elem(), inner, pres...)
case reflect.Struct: case reflect.Struct:
pres = append(pres, t)
if t == reflect.TypeOf(time.Time{}) { if t == reflect.TypeOf(time.Time{}) {
return true return true
} }
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
if !isValidSchema(t.Field(i).Type, true) { if !isValidSchema(t.Field(i).Type, true, pres...) {
return false return false
} }
} }