From 1ac55389b17a07669c7f45331f07bc7c6d5d8cb9 Mon Sep 17 00:00:00 2001 From: ElvinChan Date: Sat, 3 Nov 2018 09:51:41 +0800 Subject: [PATCH] #15 Extract host from referer of spec's request header --- spec.go | 7 ++++++- spec_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/spec.go b/spec.go index cd60e94..86477bb 100644 --- a/spec.go +++ b/spec.go @@ -3,6 +3,7 @@ package echoswagger import ( "encoding/xml" "net/http" + "net/url" "reflect" "github.com/labstack/echo" @@ -21,13 +22,17 @@ func (r *Root) Spec(c echo.Context) error { if r.err != nil { return c.String(http.StatusInternalServerError, r.err.Error()) } + if uri, err := url.ParseRequestURI(c.Request().Referer()); err == nil { + r.spec.Host = uri.Host + } else { + r.spec.Host = c.Request().Host + } return c.JSON(http.StatusOK, r.spec) } func (r *Root) genSpec(c echo.Context) error { r.spec.Swagger = SwaggerVersion r.spec.Paths = make(map[string]interface{}) - r.spec.Host = c.Request().Host for i := range r.groups { group := &r.groups[i] diff --git a/spec_test.go b/spec_test.go index df65ed1..d2607d9 100644 --- a/spec_test.go +++ b/spec_test.go @@ -1,6 +1,7 @@ package echoswagger import ( + "encoding/json" "net/http" "net/http/httptest" "reflect" @@ -107,6 +108,52 @@ func TestSpec(t *testing.T) { }) } +func TestRefererHost(t *testing.T) { + tests := []struct { + name, referer, host string + }{ + { + referer: "http://localhost:1323/doc", + host: "localhost:1323", + name: "A", + }, + { + referer: "1/doc", + host: "127.0.0.1", + name: "B", + }, + { + referer: "http://user:pass@github.com", + host: "github.com", + name: "C", + }, + { + referer: "https://www.github.com?q=1", + host: "www.github.com", + name: "D", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := prepareApiRoot() + e := r.(*Root).echo + req := httptest.NewRequest(echo.GET, "http://127.0.0.1/doc/swagger.json", nil) + req.Header.Add("referer", tt.referer) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if assert.NoError(t, r.(*Root).Spec(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + var v struct { + Host string `json:"host"` + } + err := json.Unmarshal(rec.Body.Bytes(), &v) + assert.NoError(t, err) + assert.Equal(t, tt.host, v.Host) + } + }) + } +} + func TestAddDefinition(t *testing.T) { type DA struct { Name string