diff --git a/auth/auth.go b/auth/auth.go index 0a134ea2..3deb327c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -13,6 +13,7 @@ import ( "github.com/gophish/gophish/models" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" + "github.com/jinzhu/gorm" "golang.org/x/crypto/bcrypt" ) @@ -44,7 +45,7 @@ var ErrPasswordMismatch = errors.New("Passwords must match") func Login(r *http.Request) (bool, models.User, error) { username, password := r.FormValue("username"), r.FormValue("password") u, err := models.GetUserByUsername(username) - if err != nil && err != models.ErrUsernameTaken { + if err != nil { return false, models.User{}, err } //If we've made it here, we should have a valid user stored in u @@ -63,7 +64,7 @@ func Register(r *http.Request) (bool, error) { confirmPassword := r.FormValue("confirm_password") u, err := models.GetUserByUsername(username) // If we have an error which is not simply indicating that no user was found, report it - if err != nil { + if err != nil && err != gorm.ErrRecordNotFound { fmt.Println(err) return false, err } diff --git a/models/models.go b/models/models.go index e9ff833c..3261cea7 100644 --- a/models/models.go +++ b/models/models.go @@ -2,7 +2,6 @@ package models import ( "crypto/rand" - "errors" "fmt" "io" "log" @@ -19,9 +18,6 @@ import ( var db *gorm.DB var err error -// ErrUsernameTaken is thrown when a user attempts to register a username that is taken. -var ErrUsernameTaken = errors.New("username already taken") - // Logger is a global logger used to show informational, warning, and error messages var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile) diff --git a/models/models_test.go b/models/models_test.go index 5bcd7126..e4b507f4 100644 --- a/models/models_test.go +++ b/models/models_test.go @@ -100,6 +100,12 @@ func (s *ModelsSuite) TestGetUserExists(c *check.C) { c.Assert(u.Username, check.Equals, "admin") } +func (s *ModelsSuite) TestGetUserByUsernameWithExistingUser(c *check.C) { + u, err := GetUserByUsername("admin") + c.Assert(err, check.Equals, nil) + c.Assert(u.Username, check.Equals, "admin") +} + func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) { u, err := GetUser(100) c.Assert(err, check.Equals, gorm.ErrRecordNotFound) @@ -111,8 +117,6 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithExistingAPIKey(c *check.C) { c.Assert(err, check.Equals, nil) u, err = GetUserByAPIKey(u.ApiKey) - c.Assert(err, check.Equals, nil) - c.Assert(u.Username, check.Equals, "admin") } func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) { @@ -120,7 +124,13 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) { c.Assert(err, check.Equals, nil) u, err = GetUserByAPIKey(u.ApiKey + "test") - c.Assert(err, check.Equals, gorm.ErrRecordNotFound) + c.Assert(err, check.Equals, gorm.ErrRecordNotFound) + c.Assert(u.Username, check.Equals, "") +} + +func (s *ModelsSuite) TestGetUserByUsernameWithNotExistingUser(c *check.C) { + u, err := GetUserByUsername("test user does not exist") + c.Assert(err, check.Equals, gorm.ErrRecordNotFound) c.Assert(u.Username, check.Equals, "") } diff --git a/models/user.go b/models/user.go index 81f9081f..fe8761de 100644 --- a/models/user.go +++ b/models/user.go @@ -1,7 +1,5 @@ package models -import "github.com/jinzhu/gorm" - // User represents the user model for gophish. type User struct { Id int64 `json:"id"` @@ -31,12 +29,6 @@ func GetUserByAPIKey(key string) (User, error) { func GetUserByUsername(username string) (User, error) { u := User{} err := db.Where("username = ?", username).First(&u).Error - // No issue if we don't find a record - if err == gorm.ErrRecordNotFound { - return u, nil - } else if err == nil { - return u, ErrUsernameTaken - } return u, err }