diff --git a/auth/auth.go b/auth/auth.go index b9fc8701..191edbed 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,8 +3,12 @@ package auth import ( "database/sql" "encoding/gob" + "errors" + "fmt" + "io" "net/http" + "crypto/rand" "code.google.com/p/go.crypto/bcrypt" ctx "github.com/gorilla/context" "github.com/gorilla/securecookie" @@ -23,9 +27,9 @@ var Store = sessions.NewCookieStore( []byte(securecookie.GenerateRandomKey(64)), //Signing key []byte(securecookie.GenerateRandomKey(32))) -// CheckLogin attempts to request a SQL record with the given username. -// If successful, it then compares the received bcrypt hash. -// If all checks pass, this function sets the session id for later use. +var ErrUsernameTaken = errors.New("Username already taken") + +// Login attempts to login the user given a request. func Login(r *http.Request) (bool, error) { username, password := r.FormValue("username"), r.FormValue("password") session, _ := Store.Get(r, "gophish") @@ -50,6 +54,30 @@ func Login(r *http.Request) (bool, error) { return true, nil } +// Register attempts to register the user given a request. +func Register(r *http.Request) (bool, error) { + username, password := r.FormValue("username"), r.FormValue("password") + u := models.User{} + err := db.Conn.SelectOne(&u, "SELECT * FROM Users WHERE username=?", username) + if err != sql.ErrNoRows { + return false, ErrUsernameTaken + } + //If we've made it here, we should have a valid username given + //Let's create the password hash + h, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + u.Username = username + u.Hash = string(h) + u.APIKey = GenerateSecureKey() + if err != nil { + return false, err + } + err = db.Conn.Insert(&u) + if err != nil { + return false, err + } + return true, nil +} + // GetUserById returns the user that the given id corresponds to. If no user is found, an // error is thrown. func GetUserById(id int64) (models.User, error) { @@ -71,3 +99,10 @@ func GetUserByAPIKey(key []byte) (models.User, error) { } return u, nil } + +func GenerateSecureKey() string { + // Inspired from gorilla/securecookie + k := make([]byte, 32) + io.ReadFull(rand.Reader, k) + return fmt.Sprintf("%x", k) +} diff --git a/controllers/api.go b/controllers/api.go index 271f74b9..4e955fdb 100644 --- a/controllers/api.go +++ b/controllers/api.go @@ -1,10 +1,8 @@ package controllers import ( - "crypto/rand" "encoding/json" "fmt" - "io" "net/http" "strconv" "time" @@ -12,6 +10,7 @@ import ( ctx "github.com/gorilla/context" "github.com/gorilla/mux" "github.com/gorilla/sessions" + "github.com/jordan-wright/gophish/auth" "github.com/jordan-wright/gophish/db" "github.com/jordan-wright/gophish/models" ) @@ -36,11 +35,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "POST": u := ctx.Get(r, "user").(models.User) - // Inspired from gorilla/securecookie - k := make([]byte, 32) - _, err := io.ReadFull(rand.Reader, k) - checkError(err, w, "Error setting new API key") - u.APIKey = fmt.Sprintf("%x", k) + u.APIKey = auth.GenerateSecureKey() db.Conn.Exec("UPDATE users SET api_key=? WHERE id=?", u.APIKey, u.Id) session := ctx.Get(r, "session").(*sessions.Session) session.AddFlash(models.Flash{ @@ -75,19 +70,14 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) { c := models.Campaign{} // Put the request into a campaign err := json.NewDecoder(r.Body).Decode(&c) - checkError(err, w, "Invalid Request") + if checkError(err, w, "Invalid Request") { + return + } // Fill in the details c.CreatedDate = time.Now() c.CompletedDate = time.Time{} c.Status = IN_PROGRESS - c.Uid, err = db.Conn.SelectInt("SELECT id FROM users WHERE api_key=?", ctx.Get(r, "api_key")) - if c.Uid == 0 { - http.Error(w, "Error: Invalid API Key", http.StatusInternalServerError) - return - } - if checkError(err, w, "Invalid API Key") { - return - } + c.Uid = ctx.Get(r, "user_id").(int64) // Insert into the DB err = db.Conn.Insert(&c) if checkError(err, w, "Cannot insert campaign into database") { @@ -133,7 +123,49 @@ func API_Campaigns_Id_Launch(w http.ResponseWriter, r *http.Request) { // API_Groups returns details about the requested group. If the campaign is not // valid, API_Groups returns null. func API_Groups(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/", 302) + switch { + case r.Method == "GET": + gs := []models.Group{} + _, err := db.Conn.Select(&gs, "SELECT g.id, g.name, g.modified_date FROM groups g, users u WHERE g.uid=u.id AND u.api_key=?", ctx.Get(r, "api_key")) + if err != nil { + fmt.Println(err) + } + for _, g := range gs { + _, err := db.Conn.Select(&g.Targets, "SELECT t.id t.email FROM targets t, groups g, group_targets gt WHERE gt.gid=? AND gt.tid=t.id", g.Id) + if checkError(err, w, "Error looking up groups") { + return + } + } + gj, err := json.MarshalIndent(gs, "", " ") + if checkError(err, w, "Error looking up groups") { + return + } + writeJSON(w, gj) + //POST: Create a new group and return it as JSON + case r.Method == "POST": + g := models.Group{} + // Put the request into a group + err := json.NewDecoder(r.Body).Decode(&g) + if checkError(err, w, "Invalid Request") { + return + } + // Check to make sure targets were specified + if len(g.Targets) == 0 { + http.Error(w, "Error: No targets specified", http.StatusInternalServerError) + return + } + g.ModifiedDate = time.Now() + // Insert into the DB + err = db.Conn.Insert(&g) + if checkError(err, w, "Cannot insert group into database") { + return + } + gj, err := json.MarshalIndent(g, "", " ") + if checkError(err, w, "Error creating JSON response") { + return + } + writeJSON(w, gj) + } } // API_Campaigns_Id returns details about the requested campaign. If the campaign is not diff --git a/controllers/route.go b/controllers/route.go index bf8337c8..1a141ab4 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -61,7 +61,48 @@ func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) h func Register(w http.ResponseWriter, r *http.Request) { // If it is a post request, attempt to register the account // Now that we are all registered, we can log the user in - Login(w, r) + params := struct { + Title string + Flashes []interface{} + User models.User + Token string + }{Title: "Register", Token: nosurf.Token(r)} + session := ctx.Get(r, "session").(*sessions.Session) + switch { + case r.Method == "GET": + params.Flashes = session.Flashes() + session.Save(r, w) + getTemplate(w, "register").ExecuteTemplate(w, "base", params) + case r.Method == "POST": + //Attempt to register + succ, err := auth.Register(r) + //If we've registered, redirect to the login page + if succ { + session.AddFlash(models.Flash{ + Type: "success", + Message: "Registration successful!.", + }) + session.Save(r, w) + http.Redirect(w, r, "/login", 302) + } else { + // Check the error + m := "" + if err == auth.ErrUsernameTaken { + m = "Username already taken" + } else { + m = "Unknown error - please try again" + fmt.Println(err) + } + fmt.Println(m) + session.AddFlash(models.Flash{ + Type: "danger", + Message: m, + }) + session.Save(r, w) + http.Redirect(w, r, "/register", 302) + } + + } } func Logout(w http.ResponseWriter, r *http.Request) { @@ -134,10 +175,6 @@ func Login(w http.ResponseWriter, r *http.Request) { getTemplate(w, "login").ExecuteTemplate(w, "base", params) case r.Method == "POST": //Attempt to login - err := r.ParseForm() - if checkError(err, w, "Error parsing request") { - return - } succ, err := auth.Login(r) if checkError(err, w, "Error logging in") { return diff --git a/db/db.go b/db/db.go index 9a912db8..15eb73fa 100644 --- a/db/db.go +++ b/db/db.go @@ -30,9 +30,9 @@ func Setup() error { fmt.Println("Database not found, recreating...") createTablesSQL := []string{ //Create tables - `CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, api_key VARCHAR(32));`, + `CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, api_key VARCHAR(32), UNIQUE(username), UNIQUE(api_key));`, `CREATE TABLE campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TIMESTAMP NOT NULL, completed_date TIMESTAMP, template TEXT, status TEXT NOT NULL, uid INTEGER, FOREIGN KEY (uid) REFERENCES users(id));`, - `CREATE TABLE targets (id INTEGER PRIMARY KEY AUTOINCREMENT, address TEXT NOT NULL);`, + `CREATE TABLE targets (id INTEGER PRIMARY KEY AUTOINCREMENT, address TEXT NOT NULL, UNIQUE(address));`, `CREATE TABLE groups (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, modified_date TIMESTAMP NOT NULL);`, `CREATE TABLE user_groups (uid INTEGER NOT NULL, gid INTEGER NOT NULL, FOREIGN KEY (uid) REFERENCES users(id), FOREIGN KEY (gid) REFERENCES groups(id), UNIQUE(uid, gid))`, `CREATE TABLE group_targets (gid INTEGER NOT NULL, tid INTEGER NOT NULL, FOREIGN KEY (gid) REFERENCES groups(id), FOREIGN KEY (tid) REFERENCES targets(id), UNIQUE(gid, tid));`, diff --git a/templates/register.html b/templates/register.html new file mode 100644 index 00000000..38085385 --- /dev/null +++ b/templates/register.html @@ -0,0 +1,13 @@ +{{%define "content"%}} +