Merge pull request #16 from elvinchan/15

#15 Extract host from referer of spec's request header
This commit is contained in:
陈文强 2018-11-03 16:14:48 +08:00 committed by GitHub
commit b36b795282
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 1 deletions

View File

@ -3,6 +3,7 @@ package echoswagger
import ( import (
"encoding/xml" "encoding/xml"
"net/http" "net/http"
"net/url"
"reflect" "reflect"
"github.com/labstack/echo" "github.com/labstack/echo"
@ -21,13 +22,17 @@ func (r *Root) Spec(c echo.Context) error {
if r.err != nil { if r.err != nil {
return c.String(http.StatusInternalServerError, r.err.Error()) return c.String(http.StatusInternalServerError, r.err.Error())
} }
if uri, err := url.ParseRequestURI(c.Request().Referer()); err == nil {
r.spec.Host = uri.Host
} else {
r.spec.Host = c.Request().Host
}
return c.JSON(http.StatusOK, r.spec) return c.JSON(http.StatusOK, r.spec)
} }
func (r *Root) genSpec(c echo.Context) error { func (r *Root) genSpec(c echo.Context) error {
r.spec.Swagger = SwaggerVersion r.spec.Swagger = SwaggerVersion
r.spec.Paths = make(map[string]interface{}) r.spec.Paths = make(map[string]interface{})
r.spec.Host = c.Request().Host
for i := range r.groups { for i := range r.groups {
group := &r.groups[i] group := &r.groups[i]

View File

@ -1,6 +1,7 @@
package echoswagger package echoswagger
import ( import (
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -107,6 +108,52 @@ func TestSpec(t *testing.T) {
}) })
} }
func TestRefererHost(t *testing.T) {
tests := []struct {
name, referer, host string
}{
{
referer: "http://localhost:1323/doc",
host: "localhost:1323",
name: "A",
},
{
referer: "1/doc",
host: "127.0.0.1",
name: "B",
},
{
referer: "http://user:pass@github.com",
host: "github.com",
name: "C",
},
{
referer: "https://www.github.com?q=1",
host: "www.github.com",
name: "D",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := prepareApiRoot()
e := r.(*Root).echo
req := httptest.NewRequest(echo.GET, "http://127.0.0.1/doc/swagger.json", nil)
req.Header.Add("referer", tt.referer)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if assert.NoError(t, r.(*Root).Spec(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
var v struct {
Host string `json:"host"`
}
err := json.Unmarshal(rec.Body.Bytes(), &v)
assert.NoError(t, err)
assert.Equal(t, tt.host, v.Host)
}
})
}
}
func TestAddDefinition(t *testing.T) { func TestAddDefinition(t *testing.T) {
type DA struct { type DA struct {
Name string Name string