mirror of https://github.com/gophish/gophish
Refactor GetUserByUsername method not to suppress an error (#920)
Also adding some other tests for the User models.pull/919/head
parent
db19f0ac2a
commit
405bc5effe
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/gophish/gophish/models"
|
"github.com/gophish/gophish/models"
|
||||||
"github.com/gorilla/securecookie"
|
"github.com/gorilla/securecookie"
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ var ErrPasswordMismatch = errors.New("Passwords must match")
|
||||||
func Login(r *http.Request) (bool, models.User, error) {
|
func Login(r *http.Request) (bool, models.User, error) {
|
||||||
username, password := r.FormValue("username"), r.FormValue("password")
|
username, password := r.FormValue("username"), r.FormValue("password")
|
||||||
u, err := models.GetUserByUsername(username)
|
u, err := models.GetUserByUsername(username)
|
||||||
if err != nil && err != models.ErrUsernameTaken {
|
if err != nil {
|
||||||
return false, models.User{}, err
|
return false, models.User{}, err
|
||||||
}
|
}
|
||||||
//If we've made it here, we should have a valid user stored in u
|
//If we've made it here, we should have a valid user stored in u
|
||||||
|
@ -63,7 +64,7 @@ func Register(r *http.Request) (bool, error) {
|
||||||
confirmPassword := r.FormValue("confirm_password")
|
confirmPassword := r.FormValue("confirm_password")
|
||||||
u, err := models.GetUserByUsername(username)
|
u, err := models.GetUserByUsername(username)
|
||||||
// If we have an error which is not simply indicating that no user was found, report it
|
// If we have an error which is not simply indicating that no user was found, report it
|
||||||
if err != nil {
|
if err != nil && err != gorm.ErrRecordNotFound {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
@ -19,9 +18,6 @@ import (
|
||||||
var db *gorm.DB
|
var db *gorm.DB
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// ErrUsernameTaken is thrown when a user attempts to register a username that is taken.
|
|
||||||
var ErrUsernameTaken = errors.New("username already taken")
|
|
||||||
|
|
||||||
// Logger is a global logger used to show informational, warning, and error messages
|
// Logger is a global logger used to show informational, warning, and error messages
|
||||||
var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile)
|
var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||||
|
|
||||||
|
|
|
@ -100,6 +100,12 @@ func (s *ModelsSuite) TestGetUserExists(c *check.C) {
|
||||||
c.Assert(u.Username, check.Equals, "admin")
|
c.Assert(u.Username, check.Equals, "admin")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ModelsSuite) TestGetUserByUsernameWithExistingUser(c *check.C) {
|
||||||
|
u, err := GetUserByUsername("admin")
|
||||||
|
c.Assert(err, check.Equals, nil)
|
||||||
|
c.Assert(u.Username, check.Equals, "admin")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) {
|
func (s *ModelsSuite) TestGetUserDoesNotExist(c *check.C) {
|
||||||
u, err := GetUser(100)
|
u, err := GetUser(100)
|
||||||
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
|
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
@ -111,8 +117,6 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithExistingAPIKey(c *check.C) {
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
|
||||||
u, err = GetUserByAPIKey(u.ApiKey)
|
u, err = GetUserByAPIKey(u.ApiKey)
|
||||||
c.Assert(err, check.Equals, nil)
|
|
||||||
c.Assert(u.Username, check.Equals, "admin")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
|
func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
|
||||||
|
@ -124,6 +128,12 @@ func (s *ModelsSuite) TestGetUserByAPIKeyWithNotExistingAPIKey(c *check.C) {
|
||||||
c.Assert(u.Username, check.Equals, "")
|
c.Assert(u.Username, check.Equals, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ModelsSuite) TestGetUserByUsernameWithNotExistingUser(c *check.C) {
|
||||||
|
u, err := GetUserByUsername("test user does not exist")
|
||||||
|
c.Assert(err, check.Equals, gorm.ErrRecordNotFound)
|
||||||
|
c.Assert(u.Username, check.Equals, "")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ModelsSuite) TestPutUser(c *check.C) {
|
func (s *ModelsSuite) TestPutUser(c *check.C) {
|
||||||
u, err := GetUser(1)
|
u, err := GetUser(1)
|
||||||
u.Username = "admin_changed"
|
u.Username = "admin_changed"
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
package models
|
package models
|
||||||
|
|
||||||
import "github.com/jinzhu/gorm"
|
|
||||||
|
|
||||||
// User represents the user model for gophish.
|
// User represents the user model for gophish.
|
||||||
type User struct {
|
type User struct {
|
||||||
Id int64 `json:"id"`
|
Id int64 `json:"id"`
|
||||||
|
@ -31,12 +29,6 @@ func GetUserByAPIKey(key string) (User, error) {
|
||||||
func GetUserByUsername(username string) (User, error) {
|
func GetUserByUsername(username string) (User, error) {
|
||||||
u := User{}
|
u := User{}
|
||||||
err := db.Where("username = ?", username).First(&u).Error
|
err := db.Where("username = ?", username).First(&u).Error
|
||||||
// No issue if we don't find a record
|
|
||||||
if err == gorm.ErrRecordNotFound {
|
|
||||||
return u, nil
|
|
||||||
} else if err == nil {
|
|
||||||
return u, ErrUsernameTaken
|
|
||||||
}
|
|
||||||
return u, err
|
return u, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue