Refactor GetUserByUsername method not to suppress an error (#920)

Also adding some other tests for the User models.
pull/919/head
Shuhei Kitagawa 2018-01-12 09:37:38 +09:00 committed by Jordan Wright
parent db19f0ac2a
commit 405bc5effe
4 changed files with 16 additions and 17 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -44,7 +45,7 @@ var ErrPasswordMismatch = errors.New("Passwords must match")
func Login(r *http.Request) (bool, models.User, error) { func Login(r *http.Request) (bool, models.User, error) {
username, password := r.FormValue("username"), r.FormValue("password") username, password := r.FormValue("username"), r.FormValue("password")
u, err := models.GetUserByUsername(username) u, err := models.GetUserByUsername(username)
if err != nil && err != models.ErrUsernameTaken { if err != nil {
return false, models.User{}, err return false, models.User{}, err
} }
//If we've made it here, we should have a valid user stored in u //If we've made it here, we should have a valid user stored in u
@ -63,7 +64,7 @@ func Register(r *http.Request) (bool, error) {
confirmPassword := r.FormValue("confirm_password") confirmPassword := r.FormValue("confirm_password")
u, err := models.GetUserByUsername(username) u, err := models.GetUserByUsername(username)
// If we have an error which is not simply indicating that no user was found, report it // If we have an error which is not simply indicating that no user was found, report it
if err != nil { if err != nil && err != gorm.ErrRecordNotFound {
fmt.Println(err) fmt.Println(err)
return false, err return false, err
} }

View File

@ -2,7 +2,6 @@ package models
import ( import (
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -19,9 +18,6 @@ import (
var db *gorm.DB var db *gorm.DB
var err error var err error
// ErrUsernameTaken is thrown when a user attempts to register a username that is taken.
var ErrUsernameTaken = errors.New("username already taken")
// Logger is a global logger used to show informational, warning, and error messages // Logger is a global logger used to show informational, warning, and error messages
var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile) var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile)

View File

@ -100,6 +100,12 @@ func (s *ModelsSuite) TestGetUserExists(c *check.C) {
c.Assert(u.Username, check.Equals, "admin") c.Assert(u.Username, check.Equals, "admin")
} }
func (s *ModelsSuite) TestGetUserByUsernameWithExistingUser(c *check.C) {
u, err := GetUserByUsername("admin")
c.Assert(err, check.Equals, nil)
c.Assert(u.Username, check.Equals, "admin")
}
func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) { func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) {
u, err := GetUser(100) u, err := GetUser(100)
c.Assert(err, check.Equals, gorm.ErrRecordNotFound) c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
@ -111,8 +117,6 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithExistingAPIKey(c *check.C) {
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
u, err = GetUserByAPIKey(u.ApiKey) u, err = GetUserByAPIKey(u.ApiKey)
c.Assert(err, check.Equals, nil)
c.Assert(u.Username, check.Equals, "admin")
} }
func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) { func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
@ -124,6 +128,12 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
c.Assert(u.Username, check.Equals, "") c.Assert(u.Username, check.Equals, "")
} }
func (s *ModelsSuite) TestGetUserByUsernameWithNotExistingUser(c *check.C) {
u, err := GetUserByUsername("test user does not exist")
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
c.Assert(u.Username, check.Equals, "")
}
func (s *ModelsSuite) TestPutUser(c *check.C) { func (s *ModelsSuite) TestPutUser(c *check.C) {
u, err := GetUser(1) u, err := GetUser(1)
u.Username = "admin_changed" u.Username = "admin_changed"

View File

@ -1,7 +1,5 @@
package models package models
import "github.com/jinzhu/gorm"
// User represents the user model for gophish. // User represents the user model for gophish.
type User struct { type User struct {
Id int64 `json:"id"` Id int64 `json:"id"`
@ -31,12 +29,6 @@ func GetUserByAPIKey(key string) (User, error) {
func GetUserByUsername(username string) (User, error) { func GetUserByUsername(username string) (User, error) {
u := User{} u := User{}
err := db.Where("username = ?", username).First(&u).Error err := db.Where("username = ?", username).First(&u).Error
// No issue if we don't find a record
if err == gorm.ErrRecordNotFound {
return u, nil
} else if err == nil {
return u, ErrUsernameTaken
}
return u, err return u, err
} }