Refactor GetUserByUsername method not to suppress an error (#920)

Also adding some other tests for the User models.
pull/919/head
Shuhei Kitagawa 2018-01-12 09:37:38 +09:00 committed by Jordan Wright
parent db19f0ac2a
commit 405bc5effe
4 changed files with 16 additions and 17 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/gophish/gophish/models"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
@ -44,7 +45,7 @@ var ErrPasswordMismatch = errors.New("Passwords must match")
func Login(r *http.Request) (bool, models.User, error) {
username, password := r.FormValue("username"), r.FormValue("password")
u, err := models.GetUserByUsername(username)
if err != nil && err != models.ErrUsernameTaken {
if err != nil {
return false, models.User{}, err
}
//If we've made it here, we should have a valid user stored in u
@ -63,7 +64,7 @@ func Register(r *http.Request) (bool, error) {
confirmPassword := r.FormValue("confirm_password")
u, err := models.GetUserByUsername(username)
// If we have an error which is not simply indicating that no user was found, report it
if err != nil {
if err != nil && err != gorm.ErrRecordNotFound {
fmt.Println(err)
return false, err
}

View File

@ -2,7 +2,6 @@ package models
import (
"crypto/rand"
"errors"
"fmt"
"io"
"log"
@ -19,9 +18,6 @@ import (
var db *gorm.DB
var err error
// ErrUsernameTaken is thrown when a user attempts to register a username that is taken.
var ErrUsernameTaken = errors.New("username already taken")
// Logger is a global logger used to show informational, warning, and error messages
var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile)

View File

@ -100,6 +100,12 @@ func (s *ModelsSuite) TestGetUserExists(c *check.C) {
c.Assert(u.Username, check.Equals, "admin")
}
func (s *ModelsSuite) TestGetUserByUsernameWithExistingUser(c *check.C) {
u, err := GetUserByUsername("admin")
c.Assert(err, check.Equals, nil)
c.Assert(u.Username, check.Equals, "admin")
}
func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) {
u, err := GetUser(100)
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
@ -111,8 +117,6 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithExistingAPIKey(c *check.C) {
c.Assert(err, check.Equals, nil)
u, err = GetUserByAPIKey(u.ApiKey)
c.Assert(err, check.Equals, nil)
c.Assert(u.Username, check.Equals, "admin")
}
func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
@ -120,7 +124,13 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
c.Assert(err, check.Equals, nil)
u, err = GetUserByAPIKey(u.ApiKey + "test")
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
c.Assert(u.Username, check.Equals, "")
}
func (s *ModelsSuite) TestGetUserByUsernameWithNotExistingUser(c *check.C) {
u, err := GetUserByUsername("test user does not exist")
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
c.Assert(u.Username, check.Equals, "")
}

View File

@ -1,7 +1,5 @@
package models
import "github.com/jinzhu/gorm"
// User represents the user model for gophish.
type User struct {
Id int64 `json:"id"`
@ -31,12 +29,6 @@ func GetUserByAPIKey(key string) (User, error) {
func GetUserByUsername(username string) (User, error) {
u := User{}
err := db.Where("username = ?", username).First(&u).Error
// No issue if we don't find a record
if err == gorm.ErrRecordNotFound {
return u, nil
} else if err == nil {
return u, ErrUsernameTaken
}
return u, err
}