gophish/middleware/middleware_test.go

165 lines
4.6 KiB
Go
Raw Normal View History

package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context"
"github.com/gophish/gophish/models"
)
// successHandler is the terminal handler placed behind each middleware under
// test. It writes the body "success" without setting a status explicitly, so
// a request that reaches it is answered with the default 200 OK.
var successHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	body := []byte("success")
	w.Write(body)
})
// testContext carries the values produced by setupTest that individual
// tests need when building requests.
type testContext struct {
	// apiKey is the API key of the user loaded during setup.
	apiKey string
}
// setupTest initializes the models package against an in-memory SQLite
// database (DBPath ":memory:") using the migrations under
// ../db/db_sqlite3/migrations/, and returns a testContext holding the API
// key of the user with ID 1. Any setup failure aborts the calling test.
func setupTest(t *testing.T) *testContext {
	// Attribute setup failures to the calling test, not to this helper.
	t.Helper()
	conf := &config.Config{
		DBName:         "sqlite3",
		DBPath:         ":memory:",
		MigrationsPath: "../db/db_sqlite3/migrations/",
	}
	if err := models.Setup(conf); err != nil {
		t.Fatalf("Failed creating database: %v", err)
	}
	// Get the API key to use for these tests
	u, err := models.GetUser(1)
	if err != nil {
		t.Fatalf("error getting user: %v", err)
	}
	// Build the result directly rather than through a local named ctx, which
	// would shadow the imported ctx package alias.
	return &testContext{apiKey: u.ApiKey}
}
// MiddlewarePermissionTest is a table of expectations keyed by HTTP method:
// each entry gives the HTTP status code a request using that method is
// expected to receive.
type MiddlewarePermissionTest map[string]int
// TestEnforceViewOnly ensures that only users with the ModifyObjects
// permission have the ability to send non-GET requests.
//
// Every role/method pair in the table is expected to pass through
// EnforceViewOnly and reach successHandler (200 OK).
// NOTE(review): no rejected case is exercised here; presumably both roles
// in the table carry the permission. Confirm against the models package.
func TestEnforceViewOnly(t *testing.T) {
	setupTest(t)
	permissionTests := map[string]MiddlewarePermissionTest{
		models.RoleAdmin: MiddlewarePermissionTest{
			http.MethodGet:     http.StatusOK,
			http.MethodHead:    http.StatusOK,
			http.MethodOptions: http.StatusOK,
			http.MethodPost:    http.StatusOK,
			http.MethodPut:     http.StatusOK,
			http.MethodDelete:  http.StatusOK,
		},
		models.RoleUser: MiddlewarePermissionTest{
			http.MethodGet:     http.StatusOK,
			http.MethodHead:    http.StatusOK,
			http.MethodOptions: http.StatusOK,
			http.MethodPost:    http.StatusOK,
			http.MethodPut:     http.StatusOK,
			http.MethodDelete:  http.StatusOK,
		},
	}
	for slug, checks := range permissionTests {
		role, err := models.GetRoleBySlug(slug)
		if err != nil {
			t.Fatalf("error getting role by slug %q: %v", slug, err)
		}
		for method, expected := range checks {
			req := httptest.NewRequest(method, "/", nil)
			response := httptest.NewRecorder()
			// Attach the user the middleware will inspect to the request.
			req = ctx.Set(req, "user", models.User{
				Role:   role,
				RoleID: role.ID,
			})
			EnforceViewOnly(successHandler).ServeHTTP(response, req)
			// Both tables are maps, so iteration order is random: name the
			// failing case and keep going so every mismatch is reported.
			if got := response.Code; got != expected {
				t.Errorf("incorrect status code received for role %q method %s. expected %d got %d", slug, method, expected, got)
			}
		}
	}
}
// TestRequirePermission checks RequirePermission with
// PermissionModifySystem as the required permission: the table expects
// 403 Forbidden for RoleUser and 200 OK for RoleAdmin.
func TestRequirePermission(t *testing.T) {
	setupTest(t)
	middleware := RequirePermission(models.PermissionModifySystem)
	handler := middleware(successHandler)
	permissionTests := map[string]int{
		models.RoleUser:  http.StatusForbidden,
		models.RoleAdmin: http.StatusOK,
	}
	for slug, expected := range permissionTests {
		// Keep the slug and the resolved role under distinct names so the
		// slug stays available for failure messages.
		role, err := models.GetRoleBySlug(slug)
		if err != nil {
			t.Fatalf("error getting role by slug %q: %v", slug, err)
		}
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		response := httptest.NewRecorder()
		// Attach a user holding this role; the outcome expected for it is
		// the status recorded in the table above.
		req = ctx.Set(req, "user", models.User{
			Role:   role,
			RoleID: role.ID,
		})
		handler.ServeHTTP(response, req)
		// Map iteration order is random, so name the failing role and
		// continue so the other role is still checked.
		if got := response.Code; got != expected {
			t.Errorf("incorrect status code received for role %q. expected %d got %d", slug, expected, got)
		}
	}
}
// TestRequireAPIKey checks that RequireAPIKey answers 401 Unauthorized when
// the request carries no API key.
func TestRequireAPIKey(t *testing.T) {
	setupTest(t)

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.Header.Set("Content-Type", "application/json")

	// The request has neither an api_key query parameter nor an
	// Authorization header, so it must be denied.
	RequireAPIKey(successHandler).ServeHTTP(recorder, request)

	if status := recorder.Code; status != http.StatusUnauthorized {
		t.Fatalf("incorrect status code received. expected %d got %d", http.StatusUnauthorized, status)
	}
}
// TestInvalidAPIKey checks that RequireAPIKey answers 401 Unauthorized when
// the api_key query parameter holds the value "bogus-api-key".
func TestInvalidAPIKey(t *testing.T) {
	setupTest(t)

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.Header.Set("Content-Type", "application/json")

	// Put the bogus key in the query string.
	params := request.URL.Query()
	params.Set("api_key", "bogus-api-key")
	request.URL.RawQuery = params.Encode()

	recorder := httptest.NewRecorder()
	RequireAPIKey(successHandler).ServeHTTP(recorder, request)

	if status := recorder.Code; status != http.StatusUnauthorized {
		t.Fatalf("incorrect status code received. expected %d got %d", http.StatusUnauthorized, status)
	}
}
// TestBearerToken checks that RequireAPIKey lets a request through (200 OK)
// when the API key obtained by setupTest is sent as a Bearer token in the
// Authorization header.
func TestBearerToken(t *testing.T) {
	fixture := setupTest(t)

	request := httptest.NewRequest(http.MethodGet, "/", nil)
	request.Header.Set("Content-Type", "application/json")

	// Send the key as "Authorization: Bearer <key>".
	bearer := fmt.Sprintf("Bearer %s", fixture.apiKey)
	request.Header.Add("Authorization", bearer)

	recorder := httptest.NewRecorder()
	RequireAPIKey(successHandler).ServeHTTP(recorder, request)

	if status := recorder.Code; status != http.StatusOK {
		t.Fatalf("incorrect status code received. expected %d got %d", http.StatusOK, status)
	}
}