diff --git a/controllers/route.go b/controllers/route.go index 5d2d5f3e..9109ff85 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -155,7 +155,7 @@ func (as *AdminServer) registerRoutes() { csrf.FieldName("csrf_token"), csrf.Secure(as.config.UseTLS)) adminHandler := csrfHandler(router) - adminHandler = mid.Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext) + adminHandler = mid.Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext, mid.ApplySecurityHeaders) // Setup GZIP compression gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression) diff --git a/middleware/middleware.go b/middleware/middleware.go index d4ff2019..85a32b05 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -176,6 +176,17 @@ func RequirePermission(perm string) func(http.Handler) http.HandlerFunc { } } +// ApplySecurityHeaders applies various security headers according to best- +// practices. +func ApplySecurityHeaders(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + csp := "frame-ancestors 'none';" + w.Header().Set("Content-Security-Policy", csp) + w.Header().Set("X-Frame-Options", "DENY") + next.ServeHTTP(w, r) + } +} + // JSONError returns an error in JSON format with the given // status code and message func JSONError(w http.ResponseWriter, c int, m string) { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 85915a2a..d7f15786 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -181,3 +181,19 @@ func TestPasswordResetRequired(t *testing.T) { t.Fatalf("incorrect location header received. expected %s got %s", expectedLocation, gotLocation) } } + +func TestApplySecurityHeaders(t *testing.T) { + expected := map[string]string{ + "Content-Security-Policy": "frame-ancestors 'none';", + "X-Frame-Options": "DENY", + } + req := httptest.NewRequest(http.MethodGet, "/", nil) + response := httptest.NewRecorder() + ApplySecurityHeaders(successHandler).ServeHTTP(response, req) + for header, value := range expected { + got := response.Header().Get(header) + if got != value { + t.Fatalf("incorrect security header received for %s: expected %s got %s", header, value, got) + } + } +}