Adding "next" parameter to support redirecting after successful login.

pull/890/head
Jordan Wright 2017-12-10 21:40:46 -06:00
parent 227da5c7b9
commit aa8c770e73
4 changed files with 51 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import (
"html/template" "html/template"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
"github.com/gophish/gophish/auth" "github.com/gophish/gophish/auth"
@ -267,7 +268,15 @@ func Login(w http.ResponseWriter, r *http.Request) {
if succ { if succ {
session.Values["id"] = u.Id session.Values["id"] = u.Id
session.Save(r, w) session.Save(r, w)
http.Redirect(w, r, "/", 302) next := "/"
url, err := url.Parse(r.FormValue("next"))
if err == nil {
path := url.Path
if path != "" {
next = path
}
}
http.Redirect(w, r, next, 302)
} else { } else {
Flash(w, r, "danger", "Invalid Username/Password") Flash(w, r, "danger", "Invalid Username/Password")
params.Flashes = session.Flashes() params.Flashes = session.Flashes()

View File

@ -73,3 +73,38 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
s.Equal(err, nil) s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
} }
func (s *ControllersSuite) TestSuccessfulRedirect() {
next := "/campaigns"
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
doc, err := goquery.NewDocumentFromResponse(resp)
s.Equal(err, nil)
elem := doc.Find("input[name='csrf_token']").First()
token, ok := elem.Attr("value")
s.Equal(ok, true)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", as.URL, next), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"gophish"},
"csrf_token": {token},
}.Encode()))
s.Equal(err, nil)
req.Header.Set("Cookie", resp.Header.Get("Set-Cookie"))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err = client.Do(req)
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusFound)
url, err := resp.Location()
s.Equal(err, nil)
s.Equal(url.Path, next)
}

View File

@ -94,7 +94,9 @@ func RequireLogin(handler http.Handler) http.HandlerFunc {
if u := ctx.Get(r, "user"); u != nil { if u := ctx.Get(r, "user"); u != nil {
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
} else { } else {
http.Redirect(w, r, "/login", 302) q := r.URL.Query()
q.Set("next", r.URL.Path)
http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), 302)
} }
} }
} }

View File

@ -41,7 +41,7 @@
</div> </div>
</div> </div>
<div class="container"> <div class="container">
<form class="form-signin" action="/login" method="POST"> <form class="form-signin" action="" method="POST">
<img id="logo" src="/images/logo_purple.png" /> <img id="logo" src="/images/logo_purple.png" />
<h2 class="form-signin-heading">Please sign in</h2> <h2 class="form-signin-heading">Please sign in</h2>
{{template "flashes" .Flashes}} {{template "flashes" .Flashes}}