diff --git a/auth/auth.go b/auth/auth.go index 382c1443..651adc67 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -51,9 +51,16 @@ func CheckLogin(r *http.Request) (bool, error) { return true, nil } -func GetUser(r *http.Request) models.User { - if rv := ctx.Get(r, "user"); rv != nil { - return rv.(models.User) +func GetUser(id int) (models.User, error) { + u := models.User{} + stmt, err := db.Conn.Prepare("SELECT * FROM Users WHERE id=?") + if err != nil { + return u, err } - return models.User{} + err = stmt.QueryRow(id).Scan(&u.Id, &u.Username, &u.Hash, &u.APIKey) + if err != nil { + //Return false, but don't return an error + return u, err + } + return u, nil } diff --git a/controllers/route.go b/controllers/route.go index 03005e0e..a4d76835 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -27,6 +27,7 @@ THE SOFTWARE. */ import ( + "fmt" "html/template" "net/http" @@ -34,13 +35,14 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/jordan-wright/gophish/auth" + "github.com/jordan-wright/gophish/middleware" "github.com/jordan-wright/gophish/models" ) func CreateRouter() http.Handler { router := mux.NewRouter() // Base Front-end routes - router.HandleFunc("/", Base) + router.Handle("/", middleware.Use(http.HandlerFunc(Base), middleware.RequireLogin)) router.HandleFunc("/login", Login) router.HandleFunc("/register", Register) router.HandleFunc("/campaigns", Base_Campaigns) @@ -67,6 +69,13 @@ func Register(w http.ResponseWriter, r *http.Request) { func Base(w http.ResponseWriter, r *http.Request) { // Example of using session - will be removed. + params := struct { + User models.User + Title string + Flashes []interface{} + }{} + params.User = ctx.Get(r, "user").(models.User) + fmt.Println(params.User.Username) getTemplate(w, "dashboard").ExecuteTemplate(w, "base", nil) } diff --git a/middleware/middleware.go b/middleware/middleware.go index a1157803..8320fb57 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,7 +1,6 @@ package middleware import ( - "fmt" "net/http" ctx "github.com/gorilla/context" @@ -25,10 +24,20 @@ func GetContext(handler http.Handler) http.Handler { // Set the context appropriately here. // Set the session session, _ := auth.Store.Get(r, "gophish") + // Put the session in the context so that ctx.Set(r, "session", session) + if id, ok := session.Values["id"]; ok { + u, err := auth.GetUser(id.(int)) + if err != nil { + ctx.Set(r, "user", nil) + } + ctx.Set(r, "user", u) + } else { + ctx.Set(r, "user", nil) + } handler.ServeHTTP(w, r) // Save the session - session.Save() + session.Save(r, w) // Remove context contents ctx.Clear(r) }) @@ -38,7 +47,10 @@ func GetContext(handler http.Handler) http.Handler { // If not, the function returns a 302 redirect to the login page. func RequireLogin(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println("RequireLogin called!!") - handler.ServeHTTP(w, r) + if u := ctx.Get(r, "user"); u != nil { + handler.ServeHTTP(w, r) + } else { + http.Redirect(w, r, "/login", 302) + } }) }