diff --git a/converter.go b/converter.go index c5dec6d..8f092a1 100644 --- a/converter.go +++ b/converter.go @@ -64,74 +64,38 @@ func proccessPath(path string) string { return path } -// converter returns string to target type converter for a reflect.StructField -func converter(f reflect.StructField) func(s string) (interface{}, error) { - switch f.Type.Kind() { - case reflect.Bool: - return func(s string) (interface{}, error) { - v, err := strconv.ParseBool(s) - return v, err - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: +func converter(t reflect.Type) func(s string) (interface{}, error) { + st, sf := toSwaggerType(t) + if st == "integer" && sf == "int32" { return func(s string) (interface{}, error) { v, err := strconv.Atoi(s) return v, err } - case reflect.Int64, reflect.Uint64: + } else if st == "integer" && sf == "int64" { return func(s string) (interface{}, error) { v, err := strconv.ParseInt(s, 10, 64) return v, err } - case reflect.Float32: + } else if st == "number" && sf == "float" { return func(s string) (interface{}, error) { v, err := strconv.ParseFloat(s, 32) return float32(v), err } - case reflect.Float64: + } else if st == "number" && sf == "double" { return func(s string) (interface{}, error) { v, err := strconv.ParseFloat(s, 64) return v, err } - default: + } else if st == "boolean" && sf == "boolean" { + return func(s string) (interface{}, error) { + v, err := strconv.ParseBool(s) + return v, err + } + } else if st == "array" && sf == "array" { + return converter(t.Elem()) + } else { return func(s string) (interface{}, error) { return s, nil } } } - -func asString(rv reflect.Value) (string, bool) { - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - av := rv.Int() - if av != 0 { - return strconv.FormatInt(av, 10), true - } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - av := rv.Uint() - if av != 0 { - return strconv.FormatUint(av, 10), true - } - case reflect.Float64: - av := rv.Float() - if av != 0 { - return strconv.FormatFloat(av, 'g', -1, 64), true - } - case reflect.Float32: - av := rv.Float() - if av != 0 { - return strconv.FormatFloat(av, 'g', -1, 32), true - } - case reflect.Bool: - av := rv.Bool() - if av { - return strconv.FormatBool(av), true - } - case reflect.String: - av := rv.String() - if av != "" { - return av, true - } - } - return "", false -} diff --git a/generator.go b/generator.go index 3598c02..23b1271 100644 --- a/generator.go +++ b/generator.go @@ -87,8 +87,9 @@ func (r *RawDefineDic) genSchema(v reflect.Value) *JSONSchema { } else { schema.Type = JSONType(st) schema.Format = sf - if ex, ok := asString(v); ok { - schema.Example = ex + zv := reflect.Zero(v.Type()) + if v.CanInterface() && zv.CanInterface() && v.Interface() != zv.Interface() { + schema.Example = v.Interface() } } return schema diff --git a/internal_test.go b/internal_test.go new file mode 100644 index 0000000..e14dce2 --- /dev/null +++ b/internal_test.go @@ -0,0 +1,161 @@ +package echoswagger + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParamTypes(t *testing.T) { + var pa interface{} + var pb *int64 + var pc map[string]string + var pd [][]float64 + tests := []struct { + p interface{} + panic bool + name string + }{ + { + p: pa, + panic: true, + name: "Interface type", + }, + { + p: &pa, + panic: true, + name: "Interface pointer type", + }, + { + p: &pb, + panic: false, + name: "Int type", + }, + { + p: &pc, + panic: true, + name: "Map type", + }, + { + p: nil, + panic: true, + name: "Nil type", + }, + { + p: 0, + panic: false, + name: "Int type", + }, + { + p: &pd, + panic: false, + name: "Array float64 type", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := prepareApi() + if tt.panic { + assert.Panics(t, func() { + a.AddParamPath(tt.p, tt.name, "") + }) + } else { + a.AddParamPath(tt.p, tt.name, "") + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Equal(t, len(sapi.operation.Parameters), 1) + assert.Equal(t, tt.name, sapi.operation.Parameters[0].Name) + } + }) + } +} + +func TestSchemaTypes(t *testing.T) { + var pa interface{} + var pb map[string]string + type PT struct { + Name string + ExpiredAt time.Time + } + var pc map[PT]string + var pd PT + var pe map[time.Time]string + var pf map[*int]string + type PU struct { + Unknown interface{} + } + var pg PU + tests := []struct { + p interface{} + panic bool + name string + }{ + { + p: pa, + panic: true, + name: "Interface type", + }, + { + p: nil, + panic: true, + name: "Nil type", + }, + { + p: "", + panic: false, + name: "String type", + }, + { + p: &pb, + panic: false, + name: "Map type", + }, + { + p: &pc, + panic: true, + name: "Map struct type", + }, + { + p: pd, + panic: false, + name: "Struct type", + }, + { + p: &pd, + panic: false, + name: "Struct pointer type", + }, + { + p: &pe, + panic: false, + name: "Map time.Time key type", + }, + { + p: &pf, + panic: false, + name: "Map pointer key type", + }, + { + p: &pg, + panic: true, + name: "Struct inner invalid type", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := prepareApi() + if tt.panic { + assert.Panics(t, func() { + a.AddParamBody(tt.p, tt.name, "", true) + }) + } else { + a.AddParamBody(tt.p, tt.name, "", true) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Equal(t, len(sapi.operation.Parameters), 1) + assert.Equal(t, tt.name, sapi.operation.Parameters[0].Name) + } + }) + } +} diff --git a/spec_test.go b/spec_test.go new file mode 100644 index 0000000..c21c633 --- /dev/null +++ b/spec_test.go @@ -0,0 +1,141 @@ +package echoswagger + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/labstack/echo" + "github.com/stretchr/testify/assert" +) + +var handleWithFilter func(handlerFunc echo.HandlerFunc, c echo.Context) error + +func TestSpec(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + r := prepareApiRoot() + e := r.(*Root).echo + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + j := `{"swagger":"2.0","info":{"title":"Project APIs","version":""},"host":"example.com","basePath":"/","schemes":["http"],"paths":{}}` + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.JSONEq(t, j, rec.Body.String()) + } + }) + + t.Run("Methods", func(t *testing.T) { + r := prepareApiRoot() + var h echo.HandlerFunc + r.GET("/", h) + r.POST("/", h) + r.PUT("/", h) + r.DELETE("/", h) + r.OPTIONS("/", h) + r.HEAD("/", h) + r.PATCH("/", h) + e := r.(*Root).echo + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + s := r.(*Root).spec + assert.Len(t, s.Paths, 1) + assert.NotNil(t, s.Paths["/"].(*Path).Get) + assert.NotNil(t, s.Paths["/"].(*Path).Post) + assert.NotNil(t, s.Paths["/"].(*Path).Put) + assert.NotNil(t, s.Paths["/"].(*Path).Delete) + assert.NotNil(t, s.Paths["/"].(*Path).Options) + assert.NotNil(t, s.Paths["/"].(*Path).Head) + assert.NotNil(t, s.Paths["/"].(*Path).Patch) + } + }) + + t.Run("ErrorGroupSecurity", func(t *testing.T) { + r := prepareApiRoot() + e := r.(*Root).echo + var h echo.HandlerFunc + r.Group("G", "/g").SetSecurity("JWT").GET("/", h) + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusInternalServerError, rec.Code) + } + }) + + t.Run("ErrorApiSecurity", func(t *testing.T) { + r := prepareApiRoot() + e := r.(*Root).echo + var h echo.HandlerFunc + r.GET("/", h).SetSecurity("JWT") + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusInternalServerError, rec.Code) + } + }) + + t.Run("CleanUp", func(t *testing.T) { + r := prepareApiRoot() + e := r.(*Root).echo + g := r.Group("Users", "users") + + var ha echo.HandlerFunc + g.DELETE("/:id", ha) + + var hb echo.HandlerFunc + r.GET("/ping", hb) + + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + j := `{"swagger":"2.0","info":{"title":"Project APIs","version":""},"host":"example.com","basePath":"/","schemes":["http"],"paths":{"/ping":{"get":{"responses":{"default":{"description":"successful operation"}}}},"/users/{id}":{"delete":{"tags":["Users"],"responses":{"default":{"description":"successful operation"}}}}},"tags":[{"name":"Users"}]}` + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.JSONEq(t, j, rec.Body.String()) + } + + assert.Nil(t, r.(*Root).echo) + assert.Nil(t, r.(*Root).defs) + assert.Len(t, r.(*Root).groups, 0) + assert.Len(t, r.(*Root).apis, 0) + }) +} + +func TestAddDefinition(t *testing.T) { + type DA struct { + Name string + DB struct { + Name string + } + } + var da DA + r := prepareApiRoot() + var h echo.HandlerFunc + a := r.GET("/", h) + a.AddParamBody(&da, "DA", "DA Struct", false) + assert.Equal(t, len(a.(*api).operation.Parameters), 1) + assert.Equal(t, "DA", a.(*api).operation.Parameters[0].Name) + assert.Equal(t, "DA Struct", a.(*api).operation.Parameters[0].Description) + assert.Equal(t, "body", a.(*api).operation.Parameters[0].In) + assert.NotNil(t, a.(*api).operation.Parameters[0].Schema) + assert.Equal(t, "#/definitions/DA", a.(*api).operation.Parameters[0].Schema.Ref) + + assert.NotNil(t, a.(*api).defs) + assert.Equal(t, reflect.ValueOf(&da).Elem(), (*a.(*api).defs)["DA"].Value) + assert.Equal(t, reflect.ValueOf(&da.DB).Elem(), (*a.(*api).defs)[""].Value) + + e := r.(*Root).echo + req := httptest.NewRequest(echo.GET, "/doc/swagger.json", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Len(t, r.(*Root).spec.Definitions, 2) + } +} diff --git a/tag.go b/tag.go index f1672a9..3cd07d0 100644 --- a/tag.go +++ b/tag.go @@ -88,7 +88,7 @@ func (p *Parameter) handleSwaggerTags(field reflect.StructField, name string, in p.Required = true } - convert := converter(field) + convert := converter(field.Type) if t, ok := tags["enum"]; ok { enums := strings.Split(t, "|") var es []interface{} @@ -129,9 +129,9 @@ func (p *Parameter) handleSwaggerTags(field reflect.StructField, name string, in } } -func (s *JSONSchema) handleSwaggerTags(field reflect.StructField, name string) { +func (s *JSONSchema) handleSwaggerTags(f reflect.StructField, name string) { propSchema := s.Properties[name] - tags := getSwaggerTags(field) + tags := getSwaggerTags(f) if t, ok := tags["desc"]; ok { propSchema.Description = t @@ -163,7 +163,7 @@ func (s *JSONSchema) handleSwaggerTags(field reflect.StructField, name string) { propSchema.ReadOnly = true } - convert := converter(field) + convert := converter(f.Type) if t, ok := tags["enum"]; ok { enums := strings.Split(t, "|") var es []interface{} @@ -201,8 +201,8 @@ func (s *JSONSchema) handleSwaggerTags(field reflect.StructField, name string) { } } -func (h *Header) handleSwaggerTags(field reflect.StructField, name string) { - tags := getSwaggerTags(field) +func (h *Header) handleSwaggerTags(f reflect.StructField, name string) { + tags := getSwaggerTags(f) var collect string if t, ok := tags["collect"]; ok && contains([]string{"csv", "ssv", "tsv", "pipes"}, t) { @@ -232,7 +232,7 @@ func (h *Header) handleSwaggerTags(field reflect.StructField, name string) { } } - convert := converter(field) + convert := converter(f.Type) if t, ok := tags["enum"]; ok { enums := strings.Split(t, "|") var es []interface{} diff --git a/tag_test.go b/tag_test.go new file mode 100644 index 0000000..af2dbfe --- /dev/null +++ b/tag_test.go @@ -0,0 +1,224 @@ +package echoswagger + +import ( + "encoding/xml" + "net/http" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSchemaSwaggerTags(t *testing.T) { + type Spot struct { + Address string `swagger:"desc(Address of Spot)"` + Matrix [][]bool `swagger:"default(true)"` + } + + type User struct { + Age int `swagger:"min(0),max(99)"` + Gender string `swagger:"enum(male|female|other),required"` + CarNos string `swagger:"minLen(5),maxLen(8)"` + Spots []*Spot `swagger:"required"` + Money **float64 `swagger:"default(0),readOnly"` + } + + a := prepareApi() + a.AddParamBody(&User{}, "Body", "", true) + sapi := a.(*api) + assert.Len(t, sapi.operation.Parameters, 1) + assert.Len(t, *sapi.defs, 2) + + su := (*sapi.defs)["User"].Schema + pu := su.Properties + assert.NotNil(t, su) + assert.NotNil(t, pu) + assert.Len(t, su.Required, 2) + assert.ElementsMatch(t, su.Required, []string{"Spots", "Gender"}) + assert.Equal(t, *pu["Age"].Minimum, float64(0)) + assert.Equal(t, *pu["Age"].Maximum, float64(99)) + assert.Len(t, pu["Gender"].Enum, 3) + assert.ElementsMatch(t, pu["Gender"].Enum, []string{"male", "female", "other"}) + assert.Equal(t, *pu["CarNos"].MinLength, int(5)) + assert.Equal(t, *pu["CarNos"].MaxLength, int(8)) + assert.Equal(t, pu["Money"].DefaultValue, float64(0)) + assert.Equal(t, pu["Money"].ReadOnly, true) + + ss := (*sapi.defs)["Spot"].Schema + ps := ss.Properties + assert.NotNil(t, ss) + assert.NotNil(t, ps) + assert.Equal(t, ps["Address"].Description, "Address of Spot") + assert.Equal(t, ps["Matrix"].Items.Items.DefaultValue, true) +} + +func TestParamSwaggerTags(t *testing.T) { + type SearchInput struct { + Q string `query:"q" swagger:"minLen(5),maxLen(8)"` + BrandIds string `query:"brandIds" swagger:"collect(csv),allowEmpty"` + Sortby [][]string `query:"sortby" swagger:"default(id),allowEmpty"` + Order []int `query:"order" swagger:"enum(0|1|n)"` + SkipCount int `query:"skipCount" swagger:"min(0),max(999)"` + MaxResultCount int `query:"maxResultCount" swagger:"desc(items count in one page)"` + } + + a := prepareApi() + a.AddParamQueryNested(SearchInput{}) + o := a.(*api).operation + assert.Len(t, o.Parameters, 6) + assert.Equal(t, *o.Parameters[0].MinLength, 5) + assert.Equal(t, *o.Parameters[0].MaxLength, 8) + assert.Equal(t, o.Parameters[1].CollectionFormat, "csv") + assert.Equal(t, o.Parameters[1].AllowEmptyValue, true) + assert.Equal(t, o.Parameters[2].AllowEmptyValue, true) + assert.Equal(t, o.Parameters[2].Items.Items.Default, "id") + assert.ElementsMatch(t, o.Parameters[3].Items.Enum, []int{0, 1}) + assert.Equal(t, *o.Parameters[4].Minimum, float64(0)) + assert.Equal(t, *o.Parameters[4].Maximum, float64(999)) + assert.Equal(t, o.Parameters[5].Description, "items count in one page") +} + +func TestHeaderSwaggerTags(t *testing.T) { + type SearchInput struct { + Q string `json:"q" swagger:"minLen(5),maxLen(8)"` + BrandIds string `json:"brandIds" swagger:"collect(csv)"` + Sortby [][]string `json:"sortby" swagger:"default(id)"` + Order []int `json:"order" swagger:"enum(0|1|n)"` + SkipCount int `json:"skipCount" swagger:"min(0),max(999)"` + MaxResultCount int `json:"maxResultCount" swagger:"desc(items count in one page)"` + } + + a := prepareApi() + a.AddResponse(http.StatusOK, "Resp", nil, SearchInput{}) + o := a.(*api).operation + c := strconv.Itoa(http.StatusOK) + h := o.Responses[c].Headers + assert.Len(t, h, 6) + assert.Equal(t, *h["q"].MinLength, 5) + assert.Equal(t, *h["q"].MaxLength, 8) + assert.Equal(t, h["brandIds"].CollectionFormat, "csv") + assert.Equal(t, h["sortby"].Items.Items.Default, "id") + assert.ElementsMatch(t, h["order"].Items.Enum, []int{0, 1}) + assert.Equal(t, *h["skipCount"].Minimum, float64(0)) + assert.Equal(t, *h["skipCount"].Maximum, float64(999)) + assert.Equal(t, h["maxResultCount"].Description, "items count in one page") +} + +func TestXMLTags(t *testing.T) { + type Spot struct { + Id int64 `xml:",attr"` + Comment string `xml:",comment"` + Address string `xml:"AddressDetail"` + Enable bool `xml:"-"` + } + + type User struct { + X xml.Name `xml:"Users"` + Spots []*Spot `xml:"Spots>Spot"` + } + + a := prepareApi() + a.AddParamBody(&User{}, "Body", "", true) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Len(t, sapi.operation.Parameters, 1) + assert.Len(t, *sapi.defs, 2) + + su := (*sapi.defs)["User"].Schema + pu := su.Properties + assert.NotNil(t, su) + assert.NotNil(t, pu) + assert.Equal(t, su.XML.Name, "Users") + assert.NotNil(t, pu["Spots"].XML) + assert.Equal(t, pu["Spots"].XML.Name, "Spots") + assert.Equal(t, pu["Spots"].XML.Wrapped, true) + + ss := (*sapi.defs)["Spot"].Schema + ps := ss.Properties + assert.NotNil(t, ss) + assert.NotNil(t, ps) + assert.Equal(t, ss.XML.Name, "Spot") + assert.Equal(t, ps["Id"].XML.Attribute, "Id") + assert.Nil(t, ps["Comment"].XML) + assert.Equal(t, ps["Address"].XML.Name, "AddressDetail") + assert.Nil(t, ps["Enable"].XML) +} + +func TestEnumInSchema(t *testing.T) { + type User struct { + Id int64 `swagger:"enum(0|-1|200000|9.9)"` + Age int `swagger:"enum(0|-1|200000|9.9)"` + Status string `swagger:"enum(normal|stop)"` + Amount float64 `swagger:"enum(0|-0.1|ok|200.555)"` + Grade float32 `swagger:"enum(0|-0.5|ok|200.5)"` + Deleted bool `swagger:"enum(t|F),default(True)"` + } + + a := prepareApi() + a.AddParamBody(&User{}, "Body", "", true) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Len(t, sapi.operation.Parameters, 1) + assert.Len(t, *sapi.defs, 1) + + s := (*sapi.defs)["User"].Schema + assert.NotNil(t, s) + + p := s.Properties + + assert.Len(t, p["Id"].Enum, 3) + assert.ElementsMatch(t, p["Id"].Enum, []interface{}{int64(0), int64(-1), int64(200000)}) + + assert.Len(t, p["Age"].Enum, 3) + assert.ElementsMatch(t, p["Age"].Enum, []interface{}{0, -1, 200000}) + + assert.Len(t, p["Status"].Enum, 2) + assert.ElementsMatch(t, p["Status"].Enum, []interface{}{"normal", "stop"}) + + assert.Len(t, p["Amount"].Enum, 3) + assert.ElementsMatch(t, p["Amount"].Enum, []interface{}{float64(0), float64(-0.1), float64(200.555)}) + + assert.Len(t, p["Grade"].Enum, 3) + assert.ElementsMatch(t, p["Grade"].Enum, []interface{}{float32(0), float32(-0.5), float32(200.5)}) + + assert.Len(t, p["Deleted"].Enum, 2) + assert.ElementsMatch(t, p["Deleted"].Enum, []interface{}{true, false}) + assert.Equal(t, p["Deleted"].DefaultValue, true) +} + +func TestExampleInSchema(t *testing.T) { + u := struct { + Id int64 + Age int + Status string + Amount float64 + Grade float32 + Deleted bool + }{ + Id: 10000000001, + Age: 18, + Status: "normal", + Amount: 195.50, + Grade: 5.5, + Deleted: true, + } + + a := prepareApi() + a.AddParamBody(u, "Body", "", true) + sapi, ok := a.(*api) + assert.Equal(t, ok, true) + assert.Len(t, sapi.operation.Parameters, 1) + assert.Len(t, *sapi.defs, 1) + + s := (*sapi.defs)[""].Schema + assert.NotNil(t, s) + + p := s.Properties + + assert.Equal(t, p["Id"].Example, u.Id) + assert.Equal(t, p["Age"].Example, u.Age) + assert.Equal(t, p["Status"].Example, u.Status) + assert.Equal(t, p["Amount"].Example, u.Amount) + assert.Equal(t, p["Grade"].Example, u.Grade) + assert.Equal(t, p["Deleted"].Example, u.Deleted) +} diff --git a/utils.go b/utils.go index 9fc5f11..2ac203c 100644 --- a/utils.go +++ b/utils.go @@ -40,29 +40,19 @@ func equals(a []string, b []string) bool { } func indirect(v reflect.Value) reflect.Value { - t := v.Type() - v = reflect.Indirect(v) - if !v.IsValid() { - v = reflect.New(t) - } if v.Kind() == reflect.Ptr { - return indirect(v) + ev := v.Elem() + if !ev.IsValid() { + ev = reflect.New(v.Type().Elem()) + } + return indirect(ev) } return v } func indirectValue(p interface{}) reflect.Value { v := reflect.ValueOf(p) -LoopValue: - v = reflect.Indirect(v) - if !v.IsValid() { - v = reflect.New(reflect.TypeOf(p)) - } - if v.Kind() == reflect.Ptr { - goto LoopValue - } - // TODO 遍历所有子项,为Invalid初始化Value - return v + return indirect(v) } func indirectType(p interface{}) reflect.Type {