diff --git a/.gitignore b/.gitignore index f1c181e..5553e59 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,6 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out + +.DS_Store +tmp \ No newline at end of file diff --git a/converter.go b/converter.go index aa2e691..c5dec6d 100644 --- a/converter.go +++ b/converter.go @@ -1,8 +1,137 @@ package echoswagger +import ( + "reflect" + "strconv" + "strings" + "time" +) + +// toSwaggerType returns type、format for a reflect.Type in swagger format +func toSwaggerType(t reflect.Type) (string, string) { + if t == reflect.TypeOf(time.Time{}) { + return "string", "date-time" + } + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return "integer", "int32" + case reflect.Int64, reflect.Uint64: + return "integer", "int64" + case reflect.Float32: + return "number", "float" + case reflect.Float64: + return "number", "double" + case reflect.String: + return "string", "string" + case reflect.Bool: + return "boolean", "boolean" + case reflect.Struct: + return "object", "object" + case reflect.Map: + return "object", "map" + case reflect.Array, reflect.Slice: + return "array", "array" + case reflect.Ptr: + return toSwaggerType(t.Elem()) + default: + return "string", "string" + } +} + +// toSwaggerPath returns path in swagger format +func toSwaggerPath(path string) string { + var params []string + for i := 0; i < len(path); i++ { + if path[i] == ':' { + j := i + 1 + for ; i < len(path) && path[i] != '/'; i++ { + } + params = append(params, path[j:i]) + } + } + + for _, name := range params { + path = strings.Replace(path, ":"+name, "{"+name+"}", 1) + } + return proccessPath(path) +} + func proccessPath(path string) string { if len(path) == 0 || path[0] != '/' { path = "/" + path } return path } + +// converter returns string to target type converter for a reflect.StructField +func converter(f reflect.StructField) func(s string) (interface{}, error) { + switch f.Type.Kind() { + case reflect.Bool: + return func(s string) (interface{}, error) { + v, err := strconv.ParseBool(s) + return v, err + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + return func(s string) (interface{}, error) { + v, err := strconv.Atoi(s) + return v, err + } + case reflect.Int64, reflect.Uint64: + return func(s string) (interface{}, error) { + v, err := strconv.ParseInt(s, 10, 64) + return v, err + } + case reflect.Float32: + return func(s string) (interface{}, error) { + v, err := strconv.ParseFloat(s, 32) + return float32(v), err + } + case reflect.Float64: + return func(s string) (interface{}, error) { + v, err := strconv.ParseFloat(s, 64) + return v, err + } + default: + return func(s string) (interface{}, error) { + return s, nil + } + } +} + +func asString(rv reflect.Value) (string, bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + av := rv.Int() + if av != 0 { + return strconv.FormatInt(av, 10), true + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + av := rv.Uint() + if av != 0 { + return strconv.FormatUint(av, 10), true + } + case reflect.Float64: + av := rv.Float() + if av != 0 { + return strconv.FormatFloat(av, 'g', -1, 64), true + } + case reflect.Float32: + av := rv.Float() + if av != 0 { + return strconv.FormatFloat(av, 'g', -1, 32), true + } + case reflect.Bool: + av := rv.Bool() + if av { + return strconv.FormatBool(av), true + } + case reflect.String: + av := rv.String() + if av != "" { + return av, true + } + } + return "", false +} diff --git a/generator.go b/generator.go new file mode 100644 index 0000000..3598c02 --- /dev/null +++ b/generator.go @@ -0,0 +1,112 @@ +package echoswagger + +import ( + "reflect" +) + +func (Items) generate(t reflect.Type) *Items { + st, sf := toSwaggerType(t) + item := &Items{ + Type: st, + } + if st == "array" { + item.Items = Items{}.generate(t.Elem()) + item.CollectionFormat = "multi" + } else { + item.Format = sf + } + return item +} + +func (Parameter) generate(f reflect.StructField, in ParamInType) *Parameter { + name := getFieldName(f, in) + if name == "-" { + return nil + } + st, sf := toSwaggerType(f.Type) + pm := &Parameter{ + Name: name, + In: string(in), + Type: st, + } + if st == "array" { + pm.Items = Items{}.generate(f.Type.Elem()) + pm.CollectionFormat = "multi" + } else { + pm.Format = sf + } + + pm.handleSwaggerTags(f, name, in) + return pm +} + +func (Header) generate(f reflect.StructField) *Header { + name := getFieldName(f, ParamInHeader) + if name == "-" { + return nil + } + st, sf := toSwaggerType(f.Type) + h := &Header{ + Type: st, + } + if st == "array" { + h.Items = Items{}.generate(f.Type.Elem()) + h.CollectionFormat = "multi" + } else { + h.Format = sf + } + + h.handleSwaggerTags(f, name) + return h +} + +func (r *RawDefineDic) genSchema(v reflect.Value) *JSONSchema { + if !v.IsValid() { + return nil + } + v = indirect(v) + st, sf := toSwaggerType(v.Type()) + schema := &JSONSchema{} + if st == "array" { + schema.Type = JSONType(st) + if v.Len() == 0 { + v = reflect.MakeSlice(v.Type(), 1, 1) + } + schema.Items = r.genSchema(v.Index(0)) + } else if st == "object" && sf == "map" { + schema.Type = JSONType(st) + if v.Len() == 0 { + v = reflect.New(v.Type().Elem()) + } else { + v = v.MapIndex(v.MapKeys()[0]) + } + schema.AdditionalProperties = r.genSchema(v) + } else if st == "object" { + key := r.addDefinition(v) + schema.Ref = DefPrefix + key + } else { + schema.Type = JSONType(st) + schema.Format = sf + if ex, ok := asString(v); ok { + schema.Example = ex + } + } + return schema +} + +func (api) genHeader(v reflect.Value) map[string]*Header { + rt := indirect(v).Type() + if rt.Kind() != reflect.Struct { + return nil + } + mh := make(map[string]*Header) + for i := 0; i < rt.NumField(); i++ { + f := rt.Field(i) + h := Header{}.generate(f) + if h != nil { + name := getFieldName(f, ParamInHeader) + mh[name] = h + } + } + return mh +} diff --git a/internal.go b/internal.go new file mode 100644 index 0000000..f5e87c5 --- /dev/null +++ b/internal.go @@ -0,0 +1,142 @@ +package echoswagger + +import ( + "bytes" + "html/template" + "net/http" + "reflect" + + "github.com/labstack/echo" +) + +type ParamInType string + +const ( + ParamInQuery ParamInType = "query" + ParamInHeader ParamInType = "header" + ParamInPath ParamInType = "path" + ParamInFormData ParamInType = "formData" + ParamInBody ParamInType = "body" +) + +type UISetting struct { + HideTop bool + CDN string +} + +type RawDefineDic map[string]RawDefine + +type RawDefine struct { + Value reflect.Value + Schema *JSONSchema +} + +func (r *Root) docHandler(swaggerPath string) echo.HandlerFunc { + t, err := template.New("swagger").Parse(SwaggerUIContent) + if err != nil { + panic(err) + } + + return func(c echo.Context) error { + buf := new(bytes.Buffer) + t.Execute(buf, map[string]interface{}{ + "title": r.spec.Info.Title, + "url": c.Scheme() + "://" + c.Request().Host + swaggerPath, + }) + return c.HTMLBlob(http.StatusOK, buf.Bytes()) + } +} + +func (r *RawDefineDic) getKey(v reflect.Value) (bool, string) { + for k, d := range *r { + if reflect.DeepEqual(d.Value.Interface(), v.Interface()) { + return true, k + } + } + name := v.Type().Name() + for k := range *r { + if name == k { + name += "_" + } + } + return false, name +} + +func (r *routers) appendRoute(method string, route *echo.Route) *api { + opr := Operation{ + Responses: make(map[string]*Response), + } + a := api{ + route: route, + defs: r.defs, + method: method, + operation: opr, + } + r.apis = append(r.apis, a) + return &r.apis[len(r.apis)-1] +} + +func (g *api) addParams(p interface{}, in ParamInType, name, desc string, required, nest bool) Api { + if !isValidParam(reflect.TypeOf(p), nest, false) { + panic("echoswagger: invalid " + string(in) + " param") + } + rt := indirectType(p) + st, sf := toSwaggerType(rt) + if st == "object" && sf == "object" { + for i := 0; i < rt.NumField(); i++ { + pm := Parameter{}.generate(rt.Field(i), in) + if pm != nil { + pm.Name = g.operation.rename(pm.Name) + g.operation.Parameters = append(g.operation.Parameters, pm) + } + } + } else { + name = g.operation.rename(name) + pm := &Parameter{ + Name: name, + In: string(in), + Description: desc, + Required: required, + Type: st, + } + if st == "array" { + pm.Items = Items{}.generate(rt.Elem()) + pm.CollectionFormat = "multi" + } else { + pm.Format = sf + } + g.operation.Parameters = append(g.operation.Parameters, pm) + } + return g +} + +func (g *api) addBodyParams(p interface{}, name, desc string, required bool) Api { + if !isValidSchema(reflect.TypeOf(p), false) { + panic("echoswagger: invalid body parameter") + } + for _, param := range g.operation.Parameters { + if param.In == string(ParamInBody) { + panic("echoswagger: multiple body parameters are not allowed") + } + } + + rv := indirectValue(p) + pm := &Parameter{ + Name: name, + In: string(ParamInBody), + Description: desc, + Required: required, + Schema: g.defs.genSchema(rv), + } + g.operation.Parameters = append(g.operation.Parameters, pm) + return g +} + +func (o Operation) rename(s string) string { + for _, p := range o.Parameters { + if p.Name == s { + return o.rename(s + "_") + } + } + return s +} diff --git a/security.go b/security.go new file mode 100644 index 0000000..7a8e8b0 --- /dev/null +++ b/security.go @@ -0,0 +1,75 @@ +package echoswagger + +import "errors" + +type SecurityType string + +const ( + SecurityBasic SecurityType = "basic" + SecurityOAuth2 SecurityType = "oauth2" + SecurityAPIKey SecurityType = "apiKey" +) + +type SecurityInType string + +const ( + SecurityInQuery SecurityInType = "query" + SecurityInHeader SecurityInType = "header" +) + +type OAuth2FlowType string + +const ( + OAuth2FlowImplicit OAuth2FlowType = "implicit" + OAuth2FlowPassword OAuth2FlowType = "password" + OAuth2FlowApplication OAuth2FlowType = "application" + OAuth2FlowAccessCode OAuth2FlowType = "accessCode" +) + +func (r *Root) checkSecurity(name string) bool { + if name == "" { + return false + } + if _, ok := r.spec.SecurityDefinitions[name]; ok { + return false + } + return true +} + +func setSecurity(security []map[string][]string, names ...string) []map[string][]string { + m := make(map[string][]string) + for _, name := range names { + m[name] = make([]string, 0) + } + return append(security, m) +} + +func setSecurityWithScope(security []map[string][]string, s ...map[string][]string) []map[string][]string { + for _, t := range s { + if len(t) == 0 { + continue + } + for k, v := range t { + if len(v) == 0 { + t[k] = make([]string, 0) + } + } + security = append(security, t) + } + return security +} + +func (o *Operation) addSecurity(defs map[string]*SecurityDefinition, security []map[string][]string) error { + for _, scy := range security { + for k := range scy { + if _, ok := defs[k]; !ok { + return errors.New("echoswagger: not found SecurityDefinition with name: " + k) + } + } + if containsMap(o.Security, scy) { + continue + } + o.Security = append(o.Security, scy) + } + return nil +} diff --git a/spec.go b/spec.go index 7b247f6..f63524b 100644 --- a/spec.go +++ b/spec.go @@ -1,20 +1,25 @@ package echoswagger import ( - "bytes" - "fmt" - "html/template" + "encoding/xml" "net/http" + "reflect" "github.com/labstack/echo" ) -const SwaggerVersion = "2.0" +const ( + DefPrefix = "#/definitions/" + SwaggerVersion = "2.0" +) func (r *Root) Spec(c echo.Context) error { - err := r.genSpec(c) - if err != nil { - return c.String(http.StatusInternalServerError, err.Error()) + r.once.Do(func() { + r.err = r.genSpec(c) + r.cleanUp() + }) + if r.err != nil { + return c.String(http.StatusInternalServerError, r.err.Error()) } return c.JSON(http.StatusOK, r.spec) } @@ -30,26 +35,114 @@ func (r *Root) genSpec(c echo.Context) error { r.spec.Tags = append(r.spec.Tags, &group.tag) for j := range group.apis { a := &group.apis[j] - fmt.Println("Group:", a) - // TODO + if err := a.operation.addSecurity(r.spec.SecurityDefinitions, group.security); err != nil { + return err + } + if err := r.transfer(a); err != nil { + return err + } } } + for i := range r.apis { + if err := r.transfer(&r.apis[i]); err != nil { + return err + } + } + + for k, v := range *r.defs { + r.spec.Definitions[k] = v.Schema + } return nil } -func (r *Root) docHandler(swaggerPath string) echo.HandlerFunc { - t, err := template.New("swagger").Parse(SwaggerUIContent) - if err != nil { - panic(err) +func (r *Root) transfer(a *api) error { + if err := a.operation.addSecurity(r.spec.SecurityDefinitions, a.security); err != nil { + return err } - return func(c echo.Context) error { - buf := new(bytes.Buffer) - t.Execute(buf, map[string]interface{}{ - "title": r.spec.Info.Title, - "url": c.Scheme() + "://" + c.Request().Host + swaggerPath, - }) - return c.HTMLBlob(http.StatusOK, buf.Bytes()) + path := toSwaggerPath(a.route.Path) + if len(a.operation.Responses) == 0 { + a.operation.Responses["default"] = &Response{ + Description: "successful operation", + } + } + + if p, ok := r.spec.Paths[path]; ok { + p.(*Path).oprationAssign(a.method, &a.operation) + } else { + p := &Path{} + p.oprationAssign(a.method, &a.operation) + r.spec.Paths[path] = p + } + return nil +} + +func (p *Path) oprationAssign(method string, operation *Operation) { + switch method { + case echo.GET: + p.Get = operation + case echo.POST: + p.Post = operation + case echo.PUT: + p.Put = operation + case echo.PATCH: + p.Patch = operation + case echo.DELETE: + p.Delete = operation } } + +func (r *Root) cleanUp() { + r.echo = nil + r.groups = nil + r.apis = nil + r.defs = nil +} + +// addDefinition adds definition specification and returns +// key of RawDefineDic +func (r *RawDefineDic) addDefinition(v reflect.Value) string { + exist, key := r.getKey(v) + if exist { + return key + } + + schema := &JSONSchema{ + Type: "object", + Properties: make(map[string]*JSONSchema), + } + + for i := 0; i < v.NumField(); i++ { + f := v.Type().Field(i) + name := getFieldName(f, ParamInBody) + if name == "-" { + continue + } + if f.Type == reflect.TypeOf(xml.Name{}) { + schema.handleXMLTags(f) + continue + } + sp := r.genSchema(v.Field(i)) + sp.handleXMLTags(f) + if sp.XML != nil { + sp.handleChildXMLTags(sp.XML.Name, r) + } + schema.Properties[name] = sp + + schema.handleSwaggerTags(f, name) + } + + (*r)[key] = RawDefine{ + Value: v, + Schema: schema, + } + + if schema.XML == nil { + schema.XML = &XMLSchema{} + } + if schema.XML.Name == "" { + schema.XML.Name = v.Type().Name() + } + return key +} diff --git a/tag.go b/tag.go new file mode 100644 index 0000000..f1672a9 --- /dev/null +++ b/tag.go @@ -0,0 +1,350 @@ +package echoswagger + +import ( + "reflect" + "strconv" + "strings" +) + +// getTag reports is a tag exists and it's content +// search tagName in all tags when index = -1 +func getTag(field reflect.StructField, tagName string, index int) (bool, string) { + t := field.Tag.Get(tagName) + s := strings.Split(t, ",") + + if len(s) < index+1 { + return false, "" + } + + return true, strings.TrimSpace(s[index]) +} + +func getSwaggerTags(field reflect.StructField) map[string]string { + t := field.Tag.Get("swagger") + r := make(map[string]string) + for _, v := range strings.Split(t, ",") { + leftIndex := strings.Index(v, "(") + rightIndex := strings.LastIndex(v, ")") + if leftIndex > 0 && rightIndex > leftIndex { + r[v[:leftIndex]] = v[leftIndex+1 : rightIndex] + } else { + r[v] = "" + } + } + return r +} + +func getFieldName(f reflect.StructField, in ParamInType) string { + var name string + switch in { + case ParamInQuery: + name = f.Tag.Get("query") + case ParamInFormData: + name = f.Tag.Get("form") + case ParamInBody, ParamInHeader, ParamInPath: + _, name = getTag(f, "json", 0) + } + if name != "" { + return name + } else { + return f.Name + } +} + +func (p *Parameter) handleSwaggerTags(field reflect.StructField, name string, in ParamInType) { + tags := getSwaggerTags(field) + + var collect string + if t, ok := tags["collect"]; ok && contains([]string{"csv", "ssv", "tsv", "pipes"}, t) { + collect = t + } + if t, ok := tags["desc"]; ok { + p.Description = t + } + if t, ok := tags["min"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + p.Minimum = &m + } + } + if t, ok := tags["max"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + p.Maximum = &m + } + } + if t, ok := tags["minLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + p.MinLength = &m + } + } + if t, ok := tags["maxLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + p.MaxLength = &m + } + } + if _, ok := tags["allowEmpty"]; ok { + p.AllowEmptyValue = true + } + if _, ok := tags["required"]; ok || in == ParamInPath { + p.Required = true + } + + convert := converter(field) + if t, ok := tags["enum"]; ok { + enums := strings.Split(t, "|") + var es []interface{} + for _, s := range enums { + v, err := convert(s) + if err != nil { + continue + } + es = append(es, v) + } + p.Enum = es + } + if t, ok := tags["default"]; ok { + v, err := convert(t) + if err == nil { + p.Default = v + } + } + + // Move part of tags in Parameter to Items + if p.Type == "array" { + items := p.Items.latest() + items.CollectionFormat = collect + items.Minimum = p.Minimum + items.Maximum = p.Maximum + items.MinLength = p.MinLength + items.MaxLength = p.MaxLength + items.Enum = p.Enum + items.Default = p.Default + p.Minimum = nil + p.Maximum = nil + p.MinLength = nil + p.MaxLength = nil + p.Enum = nil + p.Default = nil + } else { + p.CollectionFormat = collect + } +} + +func (s *JSONSchema) handleSwaggerTags(field reflect.StructField, name string) { + propSchema := s.Properties[name] + tags := getSwaggerTags(field) + + if t, ok := tags["desc"]; ok { + propSchema.Description = t + } + if t, ok := tags["min"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + propSchema.Minimum = &m + } + } + if t, ok := tags["max"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + propSchema.Maximum = &m + } + } + if t, ok := tags["minLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + propSchema.MinLength = &m + } + } + if t, ok := tags["maxLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + propSchema.MaxLength = &m + } + } + if _, ok := tags["required"]; ok { + s.Required = append(s.Required, name) + } + if _, ok := tags["readOnly"]; ok { + propSchema.ReadOnly = true + } + + convert := converter(field) + if t, ok := tags["enum"]; ok { + enums := strings.Split(t, "|") + var es []interface{} + for _, s := range enums { + v, err := convert(s) + if err != nil { + continue + } + es = append(es, v) + } + propSchema.Enum = es + } + if t, ok := tags["default"]; ok { + v, err := convert(t) + if err == nil { + propSchema.DefaultValue = v + } + } + + // Move part of tags in Schema to Items + if propSchema.Type == "array" { + items := propSchema.Items.latest() + items.Minimum = propSchema.Minimum + items.Maximum = propSchema.Maximum + items.MinLength = propSchema.MinLength + items.MaxLength = propSchema.MaxLength + items.Enum = propSchema.Enum + items.DefaultValue = propSchema.DefaultValue + propSchema.Minimum = nil + propSchema.Maximum = nil + propSchema.MinLength = nil + propSchema.MaxLength = nil + propSchema.Enum = nil + propSchema.DefaultValue = nil + } +} + +func (h *Header) handleSwaggerTags(field reflect.StructField, name string) { + tags := getSwaggerTags(field) + + var collect string + if t, ok := tags["collect"]; ok && contains([]string{"csv", "ssv", "tsv", "pipes"}, t) { + collect = t + } + if t, ok := tags["desc"]; ok { + h.Description = t + } + if t, ok := tags["min"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + h.Minimum = &m + } + } + if t, ok := tags["max"]; ok { + if m, err := strconv.ParseFloat(t, 64); err == nil { + h.Maximum = &m + } + } + if t, ok := tags["minLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + h.MinLength = &m + } + } + if t, ok := tags["maxLen"]; ok { + if m, err := strconv.Atoi(t); err == nil { + h.MaxLength = &m + } + } + + convert := converter(field) + if t, ok := tags["enum"]; ok { + enums := strings.Split(t, "|") + var es []interface{} + for _, s := range enums { + v, err := convert(s) + if err != nil { + continue + } + es = append(es, v) + } + h.Enum = es + } + if t, ok := tags["default"]; ok { + v, err := convert(t) + if err == nil { + h.Default = v + } + } + + // Move part of tags in Header to Items + if h.Type == "array" { + items := h.Items.latest() + items.CollectionFormat = collect + items.Minimum = h.Minimum + items.Maximum = h.Maximum + items.MinLength = h.MinLength + items.MaxLength = h.MaxLength + items.Enum = h.Enum + items.Default = h.Default + h.Minimum = nil + h.Maximum = nil + h.MinLength = nil + h.MaxLength = nil + h.Enum = nil + h.Default = nil + } else { + h.CollectionFormat = collect + } +} + +func (t *Items) latest() *Items { + if t.Items != nil { + return t.Items.latest() + } + return t +} + +func (s *JSONSchema) latest() *JSONSchema { + if s.Items != nil { + return s.Items.latest() + } + return s +} + +// Not support nested elements tag eg:"a>b>c" +// Not support tags: ",chardata", ",cdata", ",comment" +// Not support embedded structure with tag ",innerxml" +// Only support nested elements tag in array type eg:"Name []string `xml:"names>name"`" +func (s *JSONSchema) handleXMLTags(f reflect.StructField) { + b, a := getTag(f, "xml", 1) + if b && contains([]string{"chardata", "cdata", "comment"}, a) { + return + } + + if b, t := getTag(f, "xml", 0); b { + if t == "-" || s.Ref != "" { + return + } else if t == "" { + t = f.Name + } + + if s.XML == nil { + s.XML = &XMLSchema{} + } + if a == "attr" { + s.XML.Attribute = t + } else { + s.XML.Name = t + } + } +} + +func (s *JSONSchema) handleChildXMLTags(rest string, r *RawDefineDic) { + if rest == "" { + return + } + + if s.Items == nil && s.Ref == "" { + if s.XML == nil { + s.XML = &XMLSchema{} + } + s.XML.Name = rest + } else if s.Ref != "" { + key := s.Ref[len(DefPrefix):] + if sc, ok := (*r)[key]; ok && sc.Schema != nil { + if sc.Schema.XML == nil { + sc.Schema.XML = &XMLSchema{} + } + sc.Schema.XML.Name = rest + } + } else { + if s.XML == nil { + s.XML = &XMLSchema{} + } + s.XML.Wrapped = true + i := strings.Index(rest, ">") + if i <= 0 { + s.XML.Name = rest + } else { + s.XML.Name = rest[:i] + rest = rest[i+1:] + s.Items.handleChildXMLTags(rest, r) + } + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..9fc5f11 --- /dev/null +++ b/utils.go @@ -0,0 +1,76 @@ +package echoswagger + +import "reflect" + +func contains(list []string, s string) bool { + for _, t := range list { + if t == s { + return true + } + } + return false +} + +func containsMap(list []map[string][]string, m map[string][]string) bool { +LoopMaps: + for _, t := range list { + if len(t) != len(m) { + continue + } + for k, v := range t { + if mv, ok := m[k]; !ok || !equals(mv, v) { + continue LoopMaps + } + } + return true + } + return false +} + +func equals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + for _, t := range a { + if !contains(b, t) { + return false + } + } + return true +} + +func indirect(v reflect.Value) reflect.Value { + t := v.Type() + v = reflect.Indirect(v) + if !v.IsValid() { + v = reflect.New(t) + } + if v.Kind() == reflect.Ptr { + return indirect(v) + } + return v +} + +func indirectValue(p interface{}) reflect.Value { + v := reflect.ValueOf(p) +LoopValue: + v = reflect.Indirect(v) + if !v.IsValid() { + v = reflect.New(reflect.TypeOf(p)) + } + if v.Kind() == reflect.Ptr { + goto LoopValue + } + // TODO 遍历所有子项,为Invalid初始化Value + return v +} + +func indirectType(p interface{}) reflect.Type { + t := reflect.TypeOf(p) +LoopType: + if t.Kind() == reflect.Ptr { + t = t.Elem() + goto LoopType + } + return t +} diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..12a017f --- /dev/null +++ b/validator.go @@ -0,0 +1,114 @@ +package echoswagger + +import ( + "net/url" + "reflect" + "regexp" + "time" +) + +var emailRegexp = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") + +func isValidEmail(str string) bool { + return emailRegexp.MatchString(str) +} + +// See: https://github.com/swagger-api/swagger-js/blob/7414ad062ba9b6d9cc397c72e7561ec775b35a9f/lib/shred/parseUri.js#L28 +func isValidURL(str string) bool { + if _, err := url.ParseRequestURI(str); err != nil { + return false + } + return true +} + +func isValidParam(t reflect.Type, nest, inner bool) bool { + if t == nil { + return false + } + switch t.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.String: + if !nest || (nest && inner) { + return true + } + case reflect.Array, reflect.Slice: + return isValidParam(t.Elem(), nest, true) + case reflect.Ptr: + return isValidParam(t.Elem(), nest, inner) + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + return true + } else if !inner { + for i := 0; i < t.NumField(); i++ { + if !isValidParam(t.Field(i).Type, nest, true) { + return false + } + } + return true + } + } + return false +} + +// isValidSchema reports a type is valid for body param. +// valid case: +// 1. Struct +// 2. Struct{ A int64 } +// 3. *[]Struct +// 4. [][]Struct +// 5. []Struct{ A []Struct } +// 6. []Struct{ A Map[string]string } +// 7. *Struct{ A []Map[int64]Struct } +// 8. Map[string]string +// 9. []int64 +// invalid case: +// 1. interface{} +// 2. Map[Struct]string +func isValidSchema(t reflect.Type, inner bool) bool { + if t == nil { + return false + } + switch t.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.String: + return true + case reflect.Array, reflect.Slice: + return isValidSchema(t.Elem(), inner) + case reflect.Map: + return isBasicType(t.Key()) && isValidSchema(t.Elem(), true) + case reflect.Ptr: + return isValidSchema(t.Elem(), inner) + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + return true + } + for i := 0; i < t.NumField(); i++ { + if !isValidSchema(t.Field(i).Type, true) { + return false + } + } + return true + } + return false +} + +func isBasicType(t reflect.Type) bool { + if t == nil { + return false + } + switch t.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.String: + return true + case reflect.Ptr: + return isBasicType(t.Elem()) + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + return true + } + } + return false +} diff --git a/wrapper.go b/wrapper.go index cfad631..f56903f 100644 --- a/wrapper.go +++ b/wrapper.go @@ -2,24 +2,172 @@ package echoswagger import ( "reflect" + "strconv" + "sync" "github.com/labstack/echo" ) -type RawDefine struct { - Value reflect.Value - Schema *JSONSchema +/* +TODO: +1.pattern +2.opreationId 重复判断 + +Notice: +1.不会对Email和URL进行验证,因为不影响页面的正常显示 +2.只支持对应于SwaggerUI页面的Schema,不支持sw、sww等协议 +3.SetSecurity/SetSecurityWithScope 传多个参数表示Security之间是AND关系;多次调用SetSecurity/SetSecurityWithScope Security之间是OR关系 +4.只支持基本类型的Map Key +*/ + +type ApiRoot interface { + // ApiGroup creates ApiGroup. Use this instead of Echo#ApiGroup. + Group(name, prefix string, m ...echo.MiddlewareFunc) ApiGroup + + // SetRequestContentType sets request content types. + SetRequestContentType(types ...string) ApiRoot + + // SetResponseContentType sets response content types. + SetResponseContentType(types ...string) ApiRoot + + // SetExternalDocs sets external docs. + SetExternalDocs(desc, url string) ApiRoot + + // AddSecurityBasic adds `SecurityDefinition` with type basic. + AddSecurityBasic(name, desc string) ApiRoot + + // AddSecurityAPIKey adds `SecurityDefinition` with type apikey. + AddSecurityAPIKey(name, desc string, in SecurityInType) ApiRoot + + // AddSecurityOAuth2 adds `SecurityDefinition` with type oauth2. + AddSecurityOAuth2(name, desc string, flow OAuth2FlowType, authorizationUrl, tokenUrl string, scopes map[string]string) ApiRoot + + // GetRaw returns raw `Swagger`. Only special case should use. + GetRaw() *Swagger + + // SetRaw sets raw `Swagger` to ApiRoot. Only special case should use. + SetRaw(s *Swagger) ApiRoot + + // Echo returns the embeded echo instance + Echo() *echo.Echo +} + +type ApiGroup interface { + // GET overrides `Echo#GET()` for sub-routes within the ApiGroup. + GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api + + // POST overrides `Echo#POST()` for sub-routes within the ApiGroup. + POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api + + // PUT overrides `Echo#PUT()` for sub-routes within the ApiGroup. + PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api + + // DELETE overrides `Echo#DELETE()` for sub-routes within the ApiGroup. + DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api + + // SetDescription sets description for ApiGroup. + SetDescription(desc string) ApiGroup + + // SetExternalDocs sets external docs for ApiGroup. + SetExternalDocs(desc, url string) ApiGroup + + // SetSecurity sets Security for all operations within the ApiGroup + // which names are reigisters by AddSecurity... functions. + SetSecurity(names ...string) ApiGroup + + // SetSecurity sets Security with scopes for all operations + // within the ApiGroup which names are reigisters + // by AddSecurity... functions. + // Should only use when Security type is oauth2. + SetSecurityWithScope(s map[string][]string) ApiGroup + + // EchoGroup returns `*echo.Group` within the ApiGroup. + EchoGroup() *echo.Group +} + +type Api interface { + // AddParamPath adds path parameter. + AddParamPath(p interface{}, name, desc string) Api + + // AddParamPathNested adds path parameters nested in p. + AddParamPathNested(p interface{}) Api + + // AddParamQuery adds query parameter. + AddParamQuery(p interface{}, name, desc string, required bool) Api + + // AddParamQueryNested adds query parameters nested in p. + AddParamQueryNested(p interface{}) Api + + // AddParamForm adds formData parameter. + AddParamForm(p interface{}, name, desc string, required bool) Api + + // AddParamFormNested adds formData parameters nested in p. + AddParamFormNested(p interface{}) Api + + // AddParamHeader adds header parameter. + AddParamHeader(p interface{}, name, desc string, required bool) Api + + // AddParamHeaderNested adds header parameters nested in p. + AddParamHeaderNested(p interface{}) Api + + // AddParamBody adds body parameter. + AddParamBody(p interface{}, name, desc string, required bool) Api + + // AddParamBody adds file parameter. + AddParamFile(name, desc string, required bool) Api + + // AddResponse adds response for Api. + AddResponse(code int, desc string, schema interface{}, header interface{}) Api + + // SetResponseContentType sets request content types. + SetRequestContentType(types ...string) Api + + // SetResponseContentType sets response content types. + SetResponseContentType(types ...string) Api + + // SetOperationId sets operationId + SetOperationId(id string) Api + + // SetDescription marks Api as deprecated. + SetDeprecated() Api + + // SetDescription sets description. + SetDescription(desc string) Api + + // SetExternalDocs sets external docs. + SetExternalDocs(desc, url string) Api + + // SetExternalDocs sets summary. + SetSummary(summary string) Api + + // SetSecurity sets Security which names are reigisters + // by AddSecurity... functions. + SetSecurity(names ...string) Api + + // SetSecurity sets Security for Api which names are + // reigisters by AddSecurity... functions. + // Should only use when Security type is oauth2. + SetSecurityWithScope(s map[string][]string) Api + + // TODO return echo.Router +} + +type routers struct { + apis []api + defs *RawDefineDic } type Root struct { - apis []api + routers spec *Swagger echo *echo.Echo groups []group + once sync.Once + err error } type group struct { - apis []api + routers echoGroup *echo.Group security []map[string][]string tag Tag @@ -27,12 +175,15 @@ type group struct { type api struct { route *echo.Route + defs *RawDefineDic security []map[string][]string method string operation Operation } -func New(e *echo.Echo, basePath, docPath string, i *Info) *Root { +// New creates ApiRoot instance. +// Multiple ApiRoot are allowed in one project. +func New(e *echo.Echo, basePath, docPath string, i *Info) ApiRoot { if e == nil { panic("echoswagger: invalid Echo instance") } @@ -45,6 +196,12 @@ func New(e *echo.Echo, basePath, docPath string, i *Info) *Root { } specPath := docPath + connector + "swagger.json" + if i == nil { + i = &Info{ + Title: "Project APIs", + } + } + defs := make(RawDefineDic) r := &Root{ echo: e, spec: &Swagger{ @@ -53,9 +210,293 @@ func New(e *echo.Echo, basePath, docPath string, i *Info) *Root { BasePath: basePath, Definitions: make(map[string]*JSONSchema), }, + routers: routers{ + defs: &defs, + }, } e.GET(docPath, r.docHandler(specPath)) e.GET(specPath, r.Spec) return r } + +func (r *Root) Group(name, prefix string, m ...echo.MiddlewareFunc) ApiGroup { + if name == "" { + panic("echowagger: invalid name of ApiGroup") + } + echoGroup := r.echo.Group(prefix, m...) + group := group{ + echoGroup: echoGroup, + routers: routers{ + defs: r.defs, + }, + } + var counter int +LoopTags: + for _, t := range r.groups { + if t.tag.Name == name { + if counter > 0 { + name = name[:len(name)-2] + } + counter++ + name += "_" + strconv.Itoa(counter) + goto LoopTags + } + } + group.tag = Tag{Name: name} + r.groups = append(r.groups, group) + return &r.groups[len(r.groups)-1] +} + +func (r *Root) SetRequestContentType(types ...string) ApiRoot { + r.spec.Consumes = types + return r +} + +func (r *Root) SetResponseContentType(types ...string) ApiRoot { + r.spec.Produces = types + return r +} + +func (r *Root) SetExternalDocs(desc, url string) ApiRoot { + r.spec.ExternalDocs = &ExternalDocs{ + Description: desc, + URL: url, + } + return r +} + +func (r *Root) AddSecurityBasic(name, desc string) ApiRoot { + if !r.checkSecurity(name) { + return r + } + sd := &SecurityDefinition{ + Type: string(SecurityBasic), + Description: desc, + } + r.spec.SecurityDefinitions[name] = sd + return r +} + +func (r *Root) AddSecurityAPIKey(name, desc string, in SecurityInType) ApiRoot { + if !r.checkSecurity(name) { + return r + } + sd := &SecurityDefinition{ + Type: string(SecurityAPIKey), + Description: desc, + Name: name, + In: string(in), + } + r.spec.SecurityDefinitions[name] = sd + return r +} + +func (r *Root) AddSecurityOAuth2(name, desc string, flow OAuth2FlowType, authorizationUrl, tokenUrl string, scopes map[string]string) ApiRoot { + if !r.checkSecurity(name) { + return r + } + sd := &SecurityDefinition{ + Type: string(SecurityOAuth2), + Description: desc, + Flow: string(flow), + AuthorizationURL: authorizationUrl, + TokenURL: tokenUrl, + Scopes: scopes, + } + r.spec.SecurityDefinitions[name] = sd + return r +} + +func (r *Root) GetRaw() *Swagger { + return r.spec +} + +func (r *Root) SetRaw(s *Swagger) ApiRoot { + r.spec = s + return r +} + +func (r *Root) Echo() *echo.Echo { + return r.echo +} + +func (g *group) GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api { + a := g.appendRoute(echo.GET, g.echoGroup.GET(path, h, m...)) + a.operation.Tags = []string{g.tag.Name} + return a +} + +func (g *group) POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api { + a := g.appendRoute(echo.POST, g.echoGroup.POST(path, h, m...)) + a.operation.Tags = []string{g.tag.Name} + return a +} + +func (g *group) PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api { + a := g.appendRoute(echo.PUT, g.echoGroup.PUT(path, h, m...)) + a.operation.Tags = []string{g.tag.Name} + return a +} + +func (g *group) DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) Api { + a := g.appendRoute(echo.DELETE, g.echoGroup.DELETE(path, h, m...)) + a.operation.Tags = []string{g.tag.Name} + return a +} + +func (g *group) SetDescription(desc string) ApiGroup { + g.tag.Description = desc + return g +} + +func (g *group) SetExternalDocs(desc, url string) ApiGroup { + g.tag.ExternalDocs = &ExternalDocs{ + Description: desc, + URL: url, + } + return g +} + +func (g *group) SetSecurity(names ...string) ApiGroup { + if len(names) == 0 { + return g + } + g.security = setSecurity(g.security, names...) + return g +} + +func (g *group) SetSecurityWithScope(s map[string][]string) ApiGroup { + g.security = setSecurityWithScope(g.security, s) + return g +} + +func (g *group) EchoGroup() *echo.Group { + return g.echoGroup +} + +func (a *api) AddParamPath(p interface{}, name, desc string) Api { + return a.addParams(p, ParamInPath, name, desc, true, false) +} + +func (a *api) AddParamPathNested(p interface{}) Api { + return a.addParams(p, ParamInPath, "", "", true, true) +} + +func (a *api) AddParamQuery(p interface{}, name, desc string, required bool) Api { + return a.addParams(p, ParamInQuery, name, desc, required, false) +} + +func (a *api) AddParamQueryNested(p interface{}) Api { + return a.addParams(p, ParamInQuery, "", "", false, true) +} + +func (a *api) AddParamForm(p interface{}, name, desc string, required bool) Api { + return a.addParams(p, ParamInFormData, name, desc, required, false) +} + +func (a *api) AddParamFormNested(p interface{}) Api { + return a.addParams(p, ParamInFormData, "", "", false, true) +} + +func (a *api) AddParamHeader(p interface{}, name, desc string, required bool) Api { + return a.addParams(p, ParamInHeader, name, desc, required, false) +} + +func (a *api) AddParamHeaderNested(p interface{}) Api { + return a.addParams(p, ParamInHeader, "", "", false, true) +} + +func (a *api) AddParamBody(p interface{}, name, desc string, required bool) Api { + return a.addBodyParams(p, name, desc, required) +} + +func (a *api) AddParamFile(name, desc string, required bool) Api { + name = a.operation.rename(name) + a.operation.Parameters = append(a.operation.Parameters, &Parameter{ + Name: name, + In: string(ParamInFormData), + Description: desc, + Required: required, + Type: "file", + }) + return a +} + +// Notice: header must be nested in a struct. +func (a *api) AddResponse(code int, desc string, schema interface{}, header interface{}) Api { + r := &Response{ + Description: desc, + } + + st := reflect.TypeOf(schema) + if st != nil { + if !isValidSchema(st, false) { + panic("echoswagger: invalid response schema") + } + r.Schema = a.defs.genSchema(reflect.ValueOf(schema)) + } + + ht := reflect.TypeOf(header) + if ht != nil { + if !isValidParam(reflect.TypeOf(header), true, false) { + panic("echoswagger: invalid response header") + } + r.Headers = a.genHeader(reflect.ValueOf(header)) + } + + cstr := strconv.Itoa(code) + a.operation.Responses[cstr] = r + return a +} + +func (a *api) SetRequestContentType(types ...string) Api { + a.operation.Consumes = types + return a +} + +func (a *api) SetResponseContentType(types ...string) Api { + a.operation.Produces = types + return a +} + +func (a *api) SetOperationId(id string) Api { + a.operation.OperationID = id + return a +} + +func (a *api) SetDeprecated() Api { + a.operation.Deprecated = true + return a +} + +func (a *api) SetDescription(desc string) Api { + a.operation.Description = desc + return a +} + +func (a *api) SetExternalDocs(desc, url string) Api { + a.operation.ExternalDocs = &ExternalDocs{ + Description: desc, + URL: url, + } + return a +} + +func (a *api) SetSummary(summary string) Api { + a.operation.Summary = summary + return a +} + +func (a *api) SetSecurity(names ...string) Api { + if len(names) == 0 { + return a + } + a.security = setSecurity(a.security, names...) + return a +} + +func (a *api) SetSecurityWithScope(s map[string][]string) Api { + a.security = setSecurityWithScope(a.security, s) + return a +}