mirror of https://github.com/gophish/gophish
205 lines
6.5 KiB
Go
205 lines
6.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
ctx "github.com/gophish/gophish/context"
|
|
"github.com/gophish/gophish/models"
|
|
"github.com/gorilla/csrf"
|
|
)
|
|
|
|
// CSRFExemptPrefixes lists the URL path prefixes whose requests are exempt
// from CSRF protection; CSRFExceptions consults it for each request it handles.
var CSRFExemptPrefixes = []string{"/api"}
|
|
|
|
// CSRFExceptions is a middleware that prevents CSRF checks on routes listed in
|
|
// CSRFExemptPrefixes.
|
|
func CSRFExceptions(handler http.Handler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
for _, prefix := range CSRFExemptPrefixes {
|
|
if strings.HasPrefix(r.URL.Path, prefix) {
|
|
r = csrf.UnsafeSkipCheck(r)
|
|
break
|
|
}
|
|
}
|
|
handler.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
// Use stacks middleware around handler so they can process the request.
//
// Each element of mid wraps the result of the previous one, so mid[0] is the
// innermost wrapper and the last element is the outermost: on a request the
// last middleware listed runs first. With no middleware, handler is returned
// unchanged.
//
// Example taken from https://github.com/gorilla/mux/pull/36#issuecomment-25849172
func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) http.HandlerFunc {
	chain := handler
	for i := range mid {
		chain = mid[i](chain)
	}
	return chain
}
|
|
|
|
// GetContext wraps each request in a function which fills in the context for a given request.
|
|
// This includes setting the User and Session keys and values as necessary for use in later functions.
|
|
func GetContext(handler http.Handler) http.HandlerFunc {
|
|
// Set the context here
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Parse the request form
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
http.Error(w, "Error parsing request", http.StatusInternalServerError)
|
|
}
|
|
// Set the context appropriately here.
|
|
// Set the session
|
|
session, _ := Store.Get(r, "gophish")
|
|
// Put the session in the context so that we can
|
|
// reuse the values in different handlers
|
|
r = ctx.Set(r, "session", session)
|
|
if id, ok := session.Values["id"]; ok {
|
|
u, err := models.GetUser(id.(int64))
|
|
if err != nil {
|
|
r = ctx.Set(r, "user", nil)
|
|
} else {
|
|
r = ctx.Set(r, "user", u)
|
|
}
|
|
} else {
|
|
r = ctx.Set(r, "user", nil)
|
|
}
|
|
handler.ServeHTTP(w, r)
|
|
// Remove context contents
|
|
ctx.Clear(r)
|
|
}
|
|
}
|
|
|
|
// RequireAPIKey ensures that a valid login cookie or API key is set (either
|
|
// the api_key GET parameter, or a Bearer token)
|
|
func RequireAPIKey(handler http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
if r.Method == "OPTIONS" {
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
|
w.Header().Set("Access-Control-Max-Age", "1000")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
|
|
return
|
|
}
|
|
r.ParseForm()
|
|
ak := r.Form.Get("api_key")
|
|
// If we can't get the API key, we'll also check for the
|
|
// Authorization Bearer token
|
|
if ak == "" {
|
|
tokens, ok := r.Header["Authorization"]
|
|
if ok && len(tokens) >= 1 {
|
|
ak = tokens[0]
|
|
ak = strings.TrimPrefix(ak, "Bearer ")
|
|
}
|
|
}
|
|
// If we can't get the API key, we'll also check if user is logged in
|
|
// via the web interface
|
|
if ak == "" {
|
|
if u := ctx.Get(r, "user"); u != nil {
|
|
ak = u.(models.User).ApiKey
|
|
}
|
|
}
|
|
if ak == "" {
|
|
JSONError(w, http.StatusUnauthorized, "Not logged in") //API Key not set
|
|
return
|
|
}
|
|
u, err := models.GetUserByAPIKey(ak)
|
|
if err != nil {
|
|
JSONError(w, http.StatusUnauthorized, "Invalid API Key")
|
|
return
|
|
}
|
|
r = ctx.Set(r, "user", u)
|
|
r = ctx.Set(r, "user_id", u.Id)
|
|
r = ctx.Set(r, "api_key", ak)
|
|
handler.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// RequireLogin checks to see if the user is currently logged in.
|
|
// If not, the function returns a 302 redirect to the login page.
|
|
func RequireLogin(handler http.Handler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if u := ctx.Get(r, "user"); u != nil {
|
|
// If a password change is required for the user, then redirect them
|
|
// to the login page
|
|
currentUser := u.(models.User)
|
|
if currentUser.PasswordChangeRequired && r.URL.Path != "/reset_password" {
|
|
q := r.URL.Query()
|
|
q.Set("next", r.URL.Path)
|
|
http.Redirect(w, r, fmt.Sprintf("/reset_password?%s", q.Encode()), http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
q := r.URL.Query()
|
|
q.Set("next", r.URL.Path)
|
|
http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
|
|
}
|
|
}
|
|
|
|
// EnforceViewOnly is a global middleware that limits the ability to edit
|
|
// objects to accounts with the PermissionModifyObjects permission.
|
|
func EnforceViewOnly(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// If the request is for any non-GET HTTP method, e.g. POST, PUT,
|
|
// or DELETE, we need to ensure the user has the appropriate
|
|
// permission.
|
|
if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
|
|
user := ctx.Get(r, "user").(models.User)
|
|
access, err := user.HasPermission(models.PermissionModifyObjects)
|
|
if err != nil {
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !access {
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
return
|
|
}
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// RequirePermission checks to see if the user has the requested permission
|
|
// before executing the handler. If the request is unauthorized, a JSONError
|
|
// is returned.
|
|
func RequirePermission(perm string) func(http.Handler) http.HandlerFunc {
|
|
return func(next http.Handler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
user := ctx.Get(r, "user").(models.User)
|
|
access, err := user.HasPermission(perm)
|
|
if err != nil {
|
|
JSONError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
if !access {
|
|
JSONError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ApplySecurityHeaders applies various security headers according to best
// practices, then delegates to next.
//
// It sets "Content-Security-Policy: frame-ancestors 'none';" and
// "X-Frame-Options: DENY". Together these tell browsers not to render the
// response inside a frame.
func ApplySecurityHeaders(next http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Content-Security-Policy", "frame-ancestors 'none';")
		h.Set("X-Frame-Options", "DENY")
		next.ServeHTTP(w, r)
	}
}
|
|
|
|
// JSONError returns an error in JSON format with the given
|
|
// status code and message
|
|
func JSONError(w http.ResponseWriter, c int, m string) {
|
|
cj, _ := json.MarshalIndent(models.Response{Success: false, Message: m}, "", " ")
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(c)
|
|
fmt.Fprintf(w, "%s", cj)
|
|
}
|