mirror of https://github.com/gophish/gophish
Refactor GetUserByUsername method not to suppress an error (#920)
Also adding some other tests for the User models.pull/919/head
parent
db19f0ac2a
commit
405bc5effe
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/gophish/gophish/models"
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jinzhu/gorm"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
|
@ -44,7 +45,7 @@ var ErrPasswordMismatch = errors.New("Passwords must match")
|
|||
func Login(r *http.Request) (bool, models.User, error) {
|
||||
username, password := r.FormValue("username"), r.FormValue("password")
|
||||
u, err := models.GetUserByUsername(username)
|
||||
if err != nil && err != models.ErrUsernameTaken {
|
||||
if err != nil {
|
||||
return false, models.User{}, err
|
||||
}
|
||||
//If we've made it here, we should have a valid user stored in u
|
||||
|
@ -63,7 +64,7 @@ func Register(r *http.Request) (bool, error) {
|
|||
confirmPassword := r.FormValue("confirm_password")
|
||||
u, err := models.GetUserByUsername(username)
|
||||
// If we have an error which is not simply indicating that no user was found, report it
|
||||
if err != nil {
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
fmt.Println(err)
|
||||
return false, err
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package models
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
@ -19,9 +18,6 @@ import (
|
|||
var db *gorm.DB
|
||||
var err error
|
||||
|
||||
// ErrUsernameTaken is thrown when a user attempts to register a username that is taken.
|
||||
var ErrUsernameTaken = errors.New("username already taken")
|
||||
|
||||
// Logger is a global logger used to show informational, warning, and error messages
|
||||
var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
|
|
|
@ -100,6 +100,12 @@ func (s *ModelsSuite) TestGetUserExists(c *check.C) {
|
|||
c.Assert(u.Username, check.Equals, "admin")
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestGetUserByUsernameWithExistingUser(c *check.C) {
|
||||
u, err := GetUserByUsername("admin")
|
||||
c.Assert(err, check.Equals, nil)
|
||||
c.Assert(u.Username, check.Equals, "admin")
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) {
|
||||
u, err := GetUser(100)
|
||||
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
@ -111,8 +117,6 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithExistingAPIKey(c *check.C) {
|
|||
c.Assert(err, check.Equals, nil)
|
||||
|
||||
u, err = GetUserByAPIKey(u.ApiKey)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
c.Assert(u.Username, check.Equals, "admin")
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
|
||||
|
@ -124,6 +128,12 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
|
|||
c.Assert(u.Username, check.Equals, "")
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestGetUserByUsernameWithNotExistingUser(c *check.C) {
|
||||
u, err := GetUserByUsername("test user does not exist")
|
||||
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
|
||||
c.Assert(u.Username, check.Equals, "")
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestPutUser(c *check.C) {
|
||||
u, err := GetUser(1)
|
||||
u.Username = "admin_changed"
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
package models
|
||||
|
||||
import "github.com/jinzhu/gorm"
|
||||
|
||||
// User represents the user model for gophish.
|
||||
type User struct {
|
||||
Id int64 `json:"id"`
|
||||
|
@ -31,12 +29,6 @@ func GetUserByAPIKey(key string) (User, error) {
|
|||
func GetUserByUsername(username string) (User, error) {
|
||||
u := User{}
|
||||
err := db.Where("username = ?", username).First(&u).Error
|
||||
// No issue if we don't find a record
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return u, nil
|
||||
} else if err == nil {
|
||||
return u, ErrUsernameTaken
|
||||
}
|
||||
return u, err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue