diff --git a/controllers/api_test.go b/controllers/api_test.go index 31d5e627..40758413 100644 --- a/controllers/api_test.go +++ b/controllers/api_test.go @@ -109,6 +109,23 @@ func (s *ControllersSuite) TestRequireAPIKey() { s.Equal(resp.StatusCode, http.StatusBadRequest) } +func (s *ControllersSuite) TestInvalidAPIKey() { + resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", as.URL, "bogus-api-key")) + s.Nil(err) + defer resp.Body.Close() + s.Equal(resp.StatusCode, http.StatusBadRequest) +} + +func (s *ControllersSuite) TestBearerToken() { + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", as.URL), nil) + s.Nil(err) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey)) + resp, err := http.DefaultClient.Do(req) + s.Nil(err) + defer resp.Body.Close() + s.Equal(resp.StatusCode, http.StatusOK) +} + func (s *ControllersSuite) TestSiteImportBaseHref() { h := "" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/middleware/middleware.go b/middleware/middleware.go index ef7a0d99..d0e0187b 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -62,8 +62,6 @@ func GetContext(handler http.Handler) http.HandlerFunc { func RequireAPIKey(handler http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - ak := r.Form.Get("api_key") w.Header().Set("Access-Control-Allow-Origin", "*") if r.Method == "OPTIONS" { w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") @@ -71,19 +69,29 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc { w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept") return } + r.ParseForm() + ak := r.Form.Get("api_key") + // If we can't get the API key, we'll also check for the + // Authorization Bearer token + if ak == "" { + tokens, ok := r.Header["Authorization"] + if ok && len(tokens) >= 1 { + ak = tokens[0] + ak = strings.TrimPrefix(ak, "Bearer ") + } + } if ak == "" { JSONError(w, 400, "API Key not set") return - } else { - u, err := models.GetUserByAPIKey(ak) - if err != nil { - JSONError(w, 400, "Invalid API Key") - return - } - r = ctx.Set(r, "user_id", u.Id) - r = ctx.Set(r, "api_key", ak) - handler.ServeHTTP(w, r) } + u, err := models.GetUserByAPIKey(ak) + if err != nil { + JSONError(w, 400, "Invalid API Key") + return + } + r = ctx.Set(r, "user_id", u.Id) + r = ctx.Set(r, "api_key", ak) + handler.ServeHTTP(w, r) } }