Merge pull request #10 from elvinchan/9
#9 Fix endless loop for recursive struct type
This commit is contained in:
commit
b84391a9dc
@ -12,6 +12,12 @@ func TestParamTypes(t *testing.T) {
|
|||||||
var pb *int64
|
var pb *int64
|
||||||
var pc map[string]string
|
var pc map[string]string
|
||||||
var pd [][]float64
|
var pd [][]float64
|
||||||
|
type Parent struct {
|
||||||
|
Child struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var pe Parent
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
p interface{}
|
p interface{}
|
||||||
panic bool
|
panic bool
|
||||||
@ -52,6 +58,11 @@ func TestParamTypes(t *testing.T) {
|
|||||||
panic: false,
|
panic: false,
|
||||||
name: "Array float64 type",
|
name: "Array float64 type",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
p: &pe,
|
||||||
|
panic: true,
|
||||||
|
name: "Struct type",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
@ -71,6 +82,53 @@ func TestParamTypes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNestedParamTypes(t *testing.T) {
|
||||||
|
var pa struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
var pb struct {
|
||||||
|
User struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
p interface{}
|
||||||
|
panic bool
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
p: 0,
|
||||||
|
panic: true,
|
||||||
|
name: "Basic type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
p: pa,
|
||||||
|
panic: false,
|
||||||
|
name: "Struct type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
p: pb,
|
||||||
|
panic: true,
|
||||||
|
name: "Nested struct type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := prepareApi()
|
||||||
|
if tt.panic {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
a.AddParamPathNested(tt.p)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
a.AddParamPathNested(tt.p)
|
||||||
|
sapi, ok := a.(*api)
|
||||||
|
assert.Equal(t, ok, true)
|
||||||
|
assert.Equal(t, len(sapi.operation.Parameters), 1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSchemaTypes(t *testing.T) {
|
func TestSchemaTypes(t *testing.T) {
|
||||||
var pa interface{}
|
var pa interface{}
|
||||||
var pb map[string]string
|
var pb map[string]string
|
||||||
@ -159,3 +217,41 @@ func TestSchemaTypes(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testUser struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
Pets []testPet
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPet struct {
|
||||||
|
Id int64
|
||||||
|
Masters []testUser
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecursiveTypes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
p interface{}
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
p: &testUser{},
|
||||||
|
name: "User",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
p: &testPet{},
|
||||||
|
name: "Pet",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := prepareApi()
|
||||||
|
a.AddParamBody(tt.p, tt.name, "", true)
|
||||||
|
sapi, ok := a.(*api)
|
||||||
|
assert.Equal(t, ok, true)
|
||||||
|
assert.Equal(t, len(sapi.operation.Parameters), 1)
|
||||||
|
assert.Equal(t, len(*sapi.defs), 2)
|
||||||
|
assert.Equal(t, tt.name, sapi.operation.Parameters[0].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
10
spec.go
10
spec.go
@ -116,6 +116,11 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string {
|
|||||||
Properties: make(map[string]*JSONSchema),
|
Properties: make(map[string]*JSONSchema),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(*r)[key] = RawDefine{
|
||||||
|
Value: v,
|
||||||
|
Schema: schema,
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
for i := 0; i < v.NumField(); i++ {
|
||||||
f := v.Type().Field(i)
|
f := v.Type().Field(i)
|
||||||
name := getFieldName(f, ParamInBody)
|
name := getFieldName(f, ParamInBody)
|
||||||
@ -136,11 +141,6 @@ func (r *RawDefineDic) addDefinition(v reflect.Value) string {
|
|||||||
schema.handleSwaggerTags(f, name)
|
schema.handleSwaggerTags(f, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
(*r)[key] = RawDefine{
|
|
||||||
Value: v,
|
|
||||||
Schema: schema,
|
|
||||||
}
|
|
||||||
|
|
||||||
if schema.XML == nil {
|
if schema.XML == nil {
|
||||||
schema.XML = &XMLSchema{}
|
schema.XML = &XMLSchema{}
|
||||||
}
|
}
|
||||||
|
17
validator.go
17
validator.go
@ -65,27 +65,34 @@ func isValidParam(t reflect.Type, nest, inner bool) bool {
|
|||||||
// invalid case:
|
// invalid case:
|
||||||
// 1. interface{}
|
// 1. interface{}
|
||||||
// 2. Map[Struct]string
|
// 2. Map[Struct]string
|
||||||
func isValidSchema(t reflect.Type, inner bool) bool {
|
func isValidSchema(t reflect.Type, inner bool, pres ...reflect.Type) bool {
|
||||||
if t == nil {
|
if t == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
for _, pre := range pres {
|
||||||
|
if t == pre {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch t.Kind() {
|
switch t.Kind() {
|
||||||
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
|
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
|
||||||
reflect.Float32, reflect.Float64, reflect.String:
|
reflect.Float32, reflect.Float64, reflect.String:
|
||||||
return true
|
return true
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
return isValidSchema(t.Elem(), inner)
|
return isValidSchema(t.Elem(), inner, pres...)
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
return isBasicType(t.Key()) && isValidSchema(t.Elem(), true)
|
return isBasicType(t.Key()) && isValidSchema(t.Elem(), true, pres...)
|
||||||
case reflect.Ptr:
|
case reflect.Ptr:
|
||||||
return isValidSchema(t.Elem(), inner)
|
return isValidSchema(t.Elem(), inner, pres...)
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
pres = append(pres, t)
|
||||||
if t == reflect.TypeOf(time.Time{}) {
|
if t == reflect.TypeOf(time.Time{}) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for i := 0; i < t.NumField(); i++ {
|
for i := 0; i < t.NumField(); i++ {
|
||||||
if !isValidSchema(t.Field(i).Type, true) {
|
if !isValidSchema(t.Field(i).Type, true, pres...) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user