Refactoring tests to remove stretchr/testify dependency

pull/1741/head
Jordan Wright 2020-02-01 21:44:50 -06:00 committed by GitHub
parent e12258bf25
commit be459e47bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 765 additions and 532 deletions

View File

@ -4,16 +4,10 @@ import (
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/suite"
) )
type ConfigSuite struct {
suite.Suite
ConfigFile *os.File
}
var validConfig = []byte(`{ var validConfig = []byte(`{
"admin_server": { "admin_server": {
"listen_url": "127.0.0.1:3333", "listen_url": "127.0.0.1:3333",
@ -33,36 +27,48 @@ var validConfig = []byte(`{
"contact_address": "" "contact_address": ""
}`) }`)
func (s *ConfigSuite) SetupTest() { func createTemporaryConfig(t *testing.T) *os.File {
f, err := ioutil.TempFile("", "gophish-config") f, err := ioutil.TempFile("", "gophish-config")
s.Nil(err) if err != nil {
s.ConfigFile = f t.Fatalf("unable to create temporary config: %v", err)
}
return f
} }
func (s *ConfigSuite) TearDownTest() { func removeTemporaryConfig(t *testing.T, f *os.File) {
err := s.ConfigFile.Close() err := f.Close()
s.Nil(err) if err != nil {
t.Fatalf("unable to remove temporary config: %v", err)
}
} }
func (s *ConfigSuite) TestLoadConfig() { func TestLoadConfig(t *testing.T) {
_, err := s.ConfigFile.Write(validConfig) f := createTemporaryConfig(t)
s.Nil(err) defer removeTemporaryConfig(t, f)
_, err := f.Write(validConfig)
if err != nil {
t.Fatalf("error writing config to temporary file: %v", err)
}
// Load the valid config // Load the valid config
conf, err := LoadConfig(s.ConfigFile.Name()) conf, err := LoadConfig(f.Name())
s.Nil(err) if err != nil {
t.Fatalf("error loading config from temporary file: %v", err)
}
expectedConfig := &Config{} expectedConfig := &Config{}
err = json.Unmarshal(validConfig, &expectedConfig) err = json.Unmarshal(validConfig, &expectedConfig)
s.Nil(err) if err != nil {
t.Fatalf("error unmarshaling config: %v", err)
}
expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName
expectedConfig.TestFlag = false expectedConfig.TestFlag = false
s.Equal(expectedConfig, conf) if !reflect.DeepEqual(expectedConfig, conf) {
t.Fatalf("invalid config received. expected %#v got %#v", expectedConfig, conf)
}
// Load an invalid config // Load an invalid config
conf, err = LoadConfig("bogusfile") conf, err = LoadConfig("bogusfile")
s.NotNil(err) if err == nil {
t.Fatalf("expected error when loading invalid config, but got %v", err)
} }
func TestConfigSuite(t *testing.T) {
suite.Run(t, new(ConfigSuite))
} }

View File

@ -6,23 +6,20 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"testing" "testing"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/stretchr/testify/suite"
) )
type APISuite struct { type testContext struct {
suite.Suite
apiKey string apiKey string
config *config.Config config *config.Config
apiServer *Server apiServer *Server
admin models.User admin models.User
} }
func (s *APISuite) SetupSuite() { func setupTest(t *testing.T) *testContext {
conf := &config.Config{ conf := &config.Config{
DBName: "sqlite3", DBName: "sqlite3",
DBPath: ":memory:", DBPath: ":memory:",
@ -30,39 +27,34 @@ func (s *APISuite) SetupSuite() {
} }
err := models.Setup(conf) err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) t.Fatalf("Failed creating database: %v", err)
} }
s.config = conf ctx := &testContext{}
s.Nil(err) ctx.config = conf
// Get the API key to use for these tests // Get the API key to use for these tests
u, err := models.GetUser(1) u, err := models.GetUser(1)
s.Nil(err) if err != nil {
s.apiKey = u.ApiKey t.Fatalf("error getting admin user: %v", err)
s.admin = u }
// Move our cwd up to the project root for help with resolving ctx.apiKey = u.ApiKey
// static assets ctx.admin = u
err = os.Chdir("../") ctx.apiServer = NewServer()
s.Nil(err) return ctx
s.apiServer = NewServer()
} }
func (s *APISuite) TearDownTest() { func tearDown(t *testing.T, ctx *testContext) {
campaigns, _ := models.GetCampaigns(1)
for _, campaign := range campaigns {
models.DeleteCampaign(campaign.Id)
}
// Cleanup all users except the original admin // Cleanup all users except the original admin
users, _ := models.GetUsers() // users, _ := models.GetUsers()
for _, user := range users { // for _, user := range users {
if user.Id == 1 { // if user.Id == 1 {
continue // continue
} // }
err := models.DeleteUser(user.Id) // err := models.DeleteUser(user.Id)
s.Nil(err) // s.Nil(err)
} // }
} }
func (s *APISuite) SetupTest() { func createTestData(t *testing.T) {
// Add a group // Add a group
group := models.Group{Name: "Test Group"} group := models.Group{Name: "Test Group"}
group.Targets = []models.Target{ group.Targets = []models.Target{
@ -73,12 +65,12 @@ func (s *APISuite) SetupTest() {
models.PostGroup(&group) models.PostGroup(&group)
// Add a template // Add a template
t := models.Template{Name: "Test Template"} template := models.Template{Name: "Test Template"}
t.Subject = "Test subject" template.Subject = "Test subject"
t.Text = "Text text" template.Text = "Text text"
t.HTML = "<html>Test</html>" template.HTML = "<html>Test</html>"
t.UserId = 1 template.UserId = 1
models.PostTemplate(&t) models.PostTemplate(&template)
// Add a landing page // Add a landing page
p := models.Page{Name: "Test Page"} p := models.Page{Name: "Test Page"}
@ -97,7 +89,7 @@ func (s *APISuite) SetupTest() {
// Set the status such that no emails are attempted // Set the status such that no emails are attempted
c := models.Campaign{Name: "Test campaign"} c := models.Campaign{Name: "Test campaign"}
c.UserId = 1 c.UserId = 1
c.Template = t c.Template = template
c.Page = p c.Page = p
c.SMTP = smtp c.SMTP = smtp
c.Groups = []models.Group{group} c.Groups = []models.Group{group}
@ -105,12 +97,13 @@ func (s *APISuite) SetupTest() {
c.UpdateStatus(models.CampaignEmailsSent) c.UpdateStatus(models.CampaignEmailsSent)
} }
func (s *APISuite) TestSiteImportBaseHref() { func TestSiteImportBaseHref(t *testing.T) {
ctx := setupTest(t)
h := "<html><head></head><body><img src=\"/test.png\"/></body></html>" h := "<html><head></head><body><img src=\"/test.png\"/></body></html>"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, h) fmt.Fprintln(w, h)
})) }))
hr := fmt.Sprintf("<html><head><base href=\"%s\"/></head><body><img src=\"/test.png\"/>\n</body></html>", ts.URL) expected := fmt.Sprintf("<html><head><base href=\"%s\"/></head><body><img src=\"/test.png\"/>\n</body></html>", ts.URL)
defer ts.Close() defer ts.Close()
req := httptest.NewRequest(http.MethodPost, "/api/import/site", req := httptest.NewRequest(http.MethodPost, "/api/import/site",
bytes.NewBuffer([]byte(fmt.Sprintf(` bytes.NewBuffer([]byte(fmt.Sprintf(`
@ -121,13 +114,13 @@ func (s *APISuite) TestSiteImportBaseHref() {
`, ts.URL)))) `, ts.URL))))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
response := httptest.NewRecorder() response := httptest.NewRecorder()
s.apiServer.ImportSite(response, req) ctx.apiServer.ImportSite(response, req)
cs := cloneResponse{} cs := cloneResponse{}
err := json.NewDecoder(response.Body).Decode(&cs) err := json.NewDecoder(response.Body).Decode(&cs)
s.Nil(err) if err != nil {
s.Equal(cs.HTML, hr) t.Fatalf("error decoding response: %v", err)
}
if cs.HTML != expected {
t.Fatalf("unexpected response received. expected %s got %s", expected, cs.HTML)
} }
func TestAPISuite(t *testing.T) {
suite.Run(t, new(APISuite))
} }

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -13,9 +14,11 @@ import (
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
) )
func (s *APISuite) createUnpriviledgedUser(slug string) *models.User { func createUnpriviledgedUser(t *testing.T, slug string) *models.User {
role, err := models.GetRoleBySlug(slug) role, err := models.GetRoleBySlug(slug)
s.Nil(err) if err != nil {
t.Fatalf("error getting role by slug: %v", err)
}
unauthorizedUser := &models.User{ unauthorizedUser := &models.User{
Username: "foo", Username: "foo",
Hash: "bar", Hash: "bar",
@ -24,56 +27,82 @@ func (s *APISuite) createUnpriviledgedUser(slug string) *models.User {
RoleID: role.ID, RoleID: role.ID,
} }
err = models.PutUser(unauthorizedUser) err = models.PutUser(unauthorizedUser)
s.Nil(err) if err != nil {
t.Fatalf("error saving unpriviledged user: %v", err)
}
return unauthorizedUser return unauthorizedUser
} }
func (s *APISuite) TestGetUsers() { func TestGetUsers(t *testing.T) {
testCtx := setupTest(t)
r := httptest.NewRequest(http.MethodGet, "/api/users", nil) r := httptest.NewRequest(http.MethodGet, "/api/users", nil)
r = ctx.Set(r, "user", s.admin) r = ctx.Set(r, "user", testCtx.admin)
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.Users(w, r) testCtx.apiServer.Users(w, r)
s.Equal(w.Code, http.StatusOK) expected := http.StatusOK
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
got := []models.User{} got := []models.User{}
err := json.NewDecoder(w.Body).Decode(&got) err := json.NewDecoder(w.Body).Decode(&got)
s.Nil(err) if err != nil {
t.Fatalf("error decoding users data: %v", err)
// We only expect one user
s.Equal(1, len(got))
// And it should be the admin user
s.Equal(s.admin.Id, got[0].Id)
} }
func (s *APISuite) TestCreateUser() { // We only expect one user
expectedUsers := 1
if len(got) != expectedUsers {
t.Fatalf("unexpected number of users returned. expected %d got %d", expectedUsers, len(got))
}
// And it should be the admin user
if testCtx.admin.Id != got[0].Id {
t.Fatalf("unexpected user received. expected %d got %d", testCtx.admin.Id, got[0].Id)
}
}
func TestCreateUser(t *testing.T) {
testCtx := setupTest(t)
payload := &userRequest{ payload := &userRequest{
Username: "foo", Username: "foo",
Password: "bar", Password: "bar",
Role: models.RoleUser, Role: models.RoleUser,
} }
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
s.Nil(err) if err != nil {
t.Fatalf("error marshaling userRequest payload: %v", err)
}
r := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBuffer(body)) r := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBuffer(body))
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
r = ctx.Set(r, "user", s.admin) r = ctx.Set(r, "user", testCtx.admin)
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.Users(w, r) testCtx.apiServer.Users(w, r)
s.Equal(w.Code, http.StatusOK) expected := http.StatusOK
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
got := &models.User{} got := &models.User{}
err = json.NewDecoder(w.Body).Decode(got) err = json.NewDecoder(w.Body).Decode(got)
s.Nil(err) if err != nil {
s.Equal(got.Username, payload.Username) t.Fatalf("error decoding user payload: %v", err)
s.Equal(got.Role.Slug, payload.Role) }
if got.Username != payload.Username {
t.Fatalf("unexpected username received. expected %s got %s", payload.Username, got.Username)
}
if got.Role.Slug != payload.Role {
t.Fatalf("unexpected role received. expected %s got %s", payload.Role, got.Role.Slug)
}
} }
// TestModifyUser tests that a user with the appropriate access is able to // TestModifyUser tests that a user with the appropriate access is able to
// modify their username and password. // modify their username and password.
func (s *APISuite) TestModifyUser() { func TestModifyUser(t *testing.T) {
unpriviledgedUser := s.createUnpriviledgedUser(models.RoleUser) testCtx := setupTest(t)
unpriviledgedUser := createUnpriviledgedUser(t, models.RoleUser)
newPassword := "new-password" newPassword := "new-password"
newUsername := "new-username" newUsername := "new-username"
payload := userRequest{ payload := userRequest{
@ -82,33 +111,48 @@ func (s *APISuite) TestModifyUser() {
Role: unpriviledgedUser.Role.Slug, Role: unpriviledgedUser.Role.Slug,
} }
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
s.Nil(err) if err != nil {
t.Fatalf("error marshaling userRequest payload: %v", err)
}
url := fmt.Sprintf("/api/users/%d", unpriviledgedUser.Id) url := fmt.Sprintf("/api/users/%d", unpriviledgedUser.Id)
r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body))
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unpriviledgedUser.ApiKey)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unpriviledgedUser.ApiKey))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.ServeHTTP(w, r) testCtx.apiServer.ServeHTTP(w, r)
response := &models.User{} response := &models.User{}
err = json.NewDecoder(w.Body).Decode(response) err = json.NewDecoder(w.Body).Decode(response)
s.Nil(err) if err != nil {
s.Equal(w.Code, http.StatusOK) t.Fatalf("error decoding user payload: %v", err)
s.Equal(response.Username, newUsername) }
expected := http.StatusOK
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
if response.Username != newUsername {
t.Fatalf("unexpected username received. expected %s got %s", newUsername, response.Username)
}
got, err := models.GetUser(unpriviledgedUser.Id) got, err := models.GetUser(unpriviledgedUser.Id)
s.Nil(err) if err != nil {
s.Equal(response.Username, got.Username) t.Fatalf("error getting unpriviledged user: %v", err)
s.Equal(newUsername, got.Username) }
if response.Username != got.Username {
t.Fatalf("unexpected username received. expected %s got %s", response.Username, got.Username)
}
err = bcrypt.CompareHashAndPassword([]byte(got.Hash), []byte(newPassword)) err = bcrypt.CompareHashAndPassword([]byte(got.Hash), []byte(newPassword))
s.Nil(err) if err != nil {
t.Fatalf("incorrect hash received for created user. expected %s got %s", []byte(newPassword), []byte(got.Hash))
}
} }
// TestUnauthorizedListUsers ensures that users without the ModifySystem // TestUnauthorizedListUsers ensures that users without the ModifySystem
// permission are unable to list the users registered in Gophish. // permission are unable to list the users registered in Gophish.
func (s *APISuite) TestUnauthorizedListUsers() { func TestUnauthorizedListUsers(t *testing.T) {
testCtx := setupTest(t)
// First, let's create a standard user which doesn't // First, let's create a standard user which doesn't
// have ModifySystem permissions. // have ModifySystem permissions.
unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser)
// We'll try to make a request to the various users API endpoints to // We'll try to make a request to the various users API endpoints to
// ensure that they fail. Previously, we could hit the handlers directly // ensure that they fail. Previously, we could hit the handlers directly
// but we need to go through the router for this test to ensure the // but we need to go through the router for this test to ensure the
@ -117,72 +161,99 @@ func (s *APISuite) TestUnauthorizedListUsers() {
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.ServeHTTP(w, r) testCtx.apiServer.ServeHTTP(w, r)
s.Equal(w.Code, http.StatusForbidden) expected := http.StatusForbidden
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
} }
// TestUnauthorizedModifyUsers verifies that users without ModifySystem // TestUnauthorizedModifyUsers verifies that users without ModifySystem
// permission (a "standard" user) can only get or modify their own information. // permission (a "standard" user) can only get or modify their own information.
func (s *APISuite) TestUnauthorizedGetUser() { func TestUnauthorizedGetUser(t *testing.T) {
testCtx := setupTest(t)
// First, we'll make sure that a user with the "user" role is unable to // First, we'll make sure that a user with the "user" role is unable to
// get the information of another user (in this case, the main admin). // get the information of another user (in this case, the main admin).
unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser)
url := fmt.Sprintf("/api/users/%d", s.admin.Id) url := fmt.Sprintf("/api/users/%d", testCtx.admin.Id)
r := httptest.NewRequest(http.MethodGet, url, nil) r := httptest.NewRequest(http.MethodGet, url, nil)
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.ServeHTTP(w, r) testCtx.apiServer.ServeHTTP(w, r)
s.Equal(w.Code, http.StatusForbidden) expected := http.StatusForbidden
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
} }
// TestUnauthorizedModifyRole ensures that users without the ModifySystem // TestUnauthorizedModifyRole ensures that users without the ModifySystem
// privilege are unable to modify their own role, preventing a potential // privilege are unable to modify their own role, preventing a potential
// privilege escalation issue. // privilege escalation issue.
func (s *APISuite) TestUnauthorizedSetRole() { func TestUnauthorizedSetRole(t *testing.T) {
unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) testCtx := setupTest(t)
unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser)
url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id) url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id)
payload := &userRequest{ payload := &userRequest{
Username: unauthorizedUser.Username, Username: unauthorizedUser.Username,
Role: models.RoleAdmin, Role: models.RoleAdmin,
} }
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
s.Nil(err) if err != nil {
t.Fatalf("error marshaling userRequest payload: %v", err)
}
r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body))
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.ServeHTTP(w, r) testCtx.apiServer.ServeHTTP(w, r)
s.Equal(w.Code, http.StatusBadRequest) expected := http.StatusBadRequest
if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
response := &models.Response{} response := &models.Response{}
err = json.NewDecoder(w.Body).Decode(response) err = json.NewDecoder(w.Body).Decode(response)
s.Nil(err) if err != nil {
s.Equal(response.Message, ErrInsufficientPermission.Error()) t.Fatalf("error decoding response payload: %v", err)
}
if response.Message != ErrInsufficientPermission.Error() {
t.Fatalf("incorrect error received when setting role. expected %s got %s", ErrInsufficientPermission.Error(), response.Message)
}
} }
// TestModifyWithExistingUsername verifies that it's not possible to modify // TestModifyWithExistingUsername verifies that it's not possible to modify
// an user's username to one which already exists. // an user's username to one which already exists.
func (s *APISuite) TestModifyWithExistingUsername() { func TestModifyWithExistingUsername(t *testing.T) {
unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) testCtx := setupTest(t)
unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser)
payload := &userRequest{ payload := &userRequest{
Username: s.admin.Username, Username: testCtx.admin.Username,
Role: unauthorizedUser.Role.Slug, Role: unauthorizedUser.Role.Slug,
} }
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
s.Nil(err) if err != nil {
t.Fatalf("error marshaling userRequest payload: %v", err)
}
url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id) url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id)
r := httptest.NewRequest(http.MethodPut, url, bytes.NewReader(body)) r := httptest.NewRequest(http.MethodPut, url, bytes.NewReader(body))
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey))
w := httptest.NewRecorder() w := httptest.NewRecorder()
s.apiServer.ServeHTTP(w, r) testCtx.apiServer.ServeHTTP(w, r)
s.Equal(w.Code, http.StatusBadRequest) expected := http.StatusBadRequest
expected := &models.Response{ if w.Code != expected {
t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code)
}
expectedResponse := &models.Response{
Message: ErrUsernameTaken.Error(), Message: ErrUsernameTaken.Error(),
Success: false, Success: false,
} }
got := &models.Response{} got := &models.Response{}
err = json.NewDecoder(w.Body).Decode(got) err = json.NewDecoder(w.Body).Decode(got)
s.Nil(err) if err != nil {
s.Equal(got.Message, expected.Message) t.Fatalf("error decoding response payload: %v", err)
}
if got.Message != expectedResponse.Message {
t.Fatalf("incorrect error received when setting role. expected %s got %s", expectedResponse.Message, got.Message)
}
} }

View File

@ -1,62 +1,75 @@
package controllers package controllers
import ( import (
"fmt"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"testing" "testing"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/stretchr/testify/suite"
) )
// ControllersSuite is a suite of tests to cover API related functions // testContext is the data required to test API related functions
type ControllersSuite struct { type testContext struct {
suite.Suite
apiKey string apiKey string
config *config.Config config *config.Config
adminServer *httptest.Server adminServer *httptest.Server
phishServer *httptest.Server phishServer *httptest.Server
origPath string
} }
func (s *ControllersSuite) SetupSuite() { func setupTest(t *testing.T) *testContext {
wd, _ := os.Getwd()
fmt.Println(wd)
conf := &config.Config{ conf := &config.Config{
DBName: "sqlite3", DBName: "sqlite3",
DBPath: ":memory:", DBPath: ":memory:",
MigrationsPath: "../db/db_sqlite3/migrations/", MigrationsPath: "../db/db_sqlite3/migrations/",
} }
abs, _ := filepath.Abs("../db/db_sqlite3/migrations/")
fmt.Printf("in controllers_test.go: %s\n", abs)
err := models.Setup(conf) err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) t.Fatalf("error setting up database: %v", err)
} }
s.config = conf ctx := &testContext{}
s.Nil(err) ctx.config = conf
// Setup the admin server for use in testing ctx.adminServer = httptest.NewUnstartedServer(NewAdminServer(ctx.config.AdminConf).server.Handler)
s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler) ctx.adminServer.Config.Addr = ctx.config.AdminConf.ListenURL
s.adminServer.Config.Addr = s.config.AdminConf.ListenURL ctx.adminServer.Start()
s.adminServer.Start()
// Get the API key to use for these tests // Get the API key to use for these tests
u, err := models.GetUser(1) u, err := models.GetUser(1)
s.Nil(err) if err != nil {
s.apiKey = u.ApiKey t.Fatalf("error getting first user from database: %v", err)
}
ctx.apiKey = u.ApiKey
// Start the phishing server // Start the phishing server
s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler) ctx.phishServer = httptest.NewUnstartedServer(NewPhishingServer(ctx.config.PhishConf).server.Handler)
s.phishServer.Config.Addr = s.config.PhishConf.ListenURL ctx.phishServer.Config.Addr = ctx.config.PhishConf.ListenURL
s.phishServer.Start() ctx.phishServer.Start()
// Move our cwd up to the project root for help with resolving // Move our cwd up to the project root for help with resolving
// static assets // static assets
origPath, _ := os.Getwd()
ctx.origPath = origPath
err = os.Chdir("../") err = os.Chdir("../")
s.Nil(err) if err != nil {
t.Fatalf("error changing directories to setup asset discovery: %v", err)
}
createTestData(t)
return ctx
} }
func (s *ControllersSuite) TearDownTest() { func tearDown(t *testing.T, ctx *testContext) {
campaigns, _ := models.GetCampaigns(1) // Tear down the admin and phishing servers
for _, campaign := range campaigns { ctx.adminServer.Close()
models.DeleteCampaign(campaign.Id) ctx.phishServer.Close()
} // Reset the path for the next test
os.Chdir(ctx.origPath)
} }
func (s *ControllersSuite) SetupTest() { func createTestData(t *testing.T) {
// Add a group // Add a group
group := models.Group{Name: "Test Group"} group := models.Group{Name: "Test Group"}
group.Targets = []models.Target{ group.Targets = []models.Target{
@ -67,12 +80,12 @@ func (s *ControllersSuite) SetupTest() {
models.PostGroup(&group) models.PostGroup(&group)
// Add a template // Add a template
t := models.Template{Name: "Test Template"} template := models.Template{Name: "Test Template"}
t.Subject = "Test subject" template.Subject = "Test subject"
t.Text = "Text text" template.Text = "Text text"
t.HTML = "<html>Test</html>" template.HTML = "<html>Test</html>"
t.UserId = 1 template.UserId = 1
models.PostTemplate(&t) models.PostTemplate(&template)
// Add a landing page // Add a landing page
p := models.Page{Name: "Test Page"} p := models.Page{Name: "Test Page"}
@ -91,20 +104,10 @@ func (s *ControllersSuite) SetupTest() {
// Set the status such that no emails are attempted // Set the status such that no emails are attempted
c := models.Campaign{Name: "Test campaign"} c := models.Campaign{Name: "Test campaign"}
c.UserId = 1 c.UserId = 1
c.Template = t c.Template = template
c.Page = p c.Page = p
c.SMTP = smtp c.SMTP = smtp
c.Groups = []models.Group{group} c.Groups = []models.Group{group}
models.PostCampaign(&c, c.UserId) models.PostCampaign(&c, c.UserId)
c.UpdateStatus(models.CampaignEmailsSent) c.UpdateStatus(models.CampaignEmailsSent)
} }
func (s *ControllersSuite) TearDownSuite() {
// Tear down the admin and phishing servers
s.adminServer.Close()
s.phishServer.Close()
}
func TestControllerSuite(t *testing.T) {
suite.Run(t, new(ControllersSuite))
}

View File

@ -5,22 +5,25 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"testing"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
) )
func (s *ControllersSuite) getFirstCampaign() models.Campaign { func getFirstCampaign(t *testing.T) models.Campaign {
campaigns, err := models.GetCampaigns(1) campaigns, err := models.GetCampaigns(1)
s.Nil(err) if err != nil {
t.Fatalf("error getting first campaign from database: %v", err)
}
return campaigns[0] return campaigns[0]
} }
func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest { func getFirstEmailRequest(t *testing.T) models.EmailRequest {
campaign := s.getFirstCampaign() campaign := getFirstCampaign(t)
req := models.EmailRequest{ req := models.EmailRequest{
TemplateId: campaign.TemplateId, TemplateId: campaign.TemplateId,
Template: campaign.Template, Template: campaign.Template,
@ -33,205 +36,334 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest {
FromAddress: campaign.SMTP.FromAddress, FromAddress: campaign.SMTP.FromAddress,
} }
err := models.PostEmailRequest(&req) err := models.PostEmailRequest(&req)
s.Nil(err) if err != nil {
t.Fatalf("error creating email request: %v", err)
}
return req return req
} }
func (s *ControllersSuite) openEmail(rid string) { func openEmail(t *testing.T, ctx *testContext, rid string) {
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) if err != nil {
t.Fatalf("error requesting /track endpoint: %v", err)
}
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) got, err := ioutil.ReadAll(resp.Body)
s.Nil(err) if err != nil {
t.Fatalf("error reading response body from /track endpoint: %v", err)
}
expected, err := ioutil.ReadFile("static/images/pixel.png") expected, err := ioutil.ReadFile("static/images/pixel.png")
s.Nil(err) if err != nil {
s.Equal(bytes.Compare(body, expected), 0) t.Fatalf("error reading local transparent pixel: %v", err)
}
if !bytes.Equal(got, expected) {
t.Fatalf("unexpected tracking pixel data received. expected %#v got %#v", expected, got)
}
} }
func (s *ControllersSuite) reportedEmail(rid string) { func openEmail404(t *testing.T, ctx *testContext, rid string) {
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) if err != nil {
s.Equal(resp.StatusCode, http.StatusNoContent) t.Fatalf("error requesting /track endpoint: %v", err)
} }
func (s *ControllersSuite) reportEmail404(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound)
}
func (s *ControllersSuite) openEmail404(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Nil(err) got := resp.StatusCode
s.Equal(resp.StatusCode, http.StatusNotFound) expected := http.StatusNotFound
if got != expected {
t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expected, got)
}
} }
func (s *ControllersSuite) clickLink(rid string, expectedHTML string) { func reportedEmail(t *testing.T, ctx *testContext, rid string) {
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) if err != nil {
defer resp.Body.Close() t.Fatalf("error requesting /report endpoint: %v", err)
body, err := ioutil.ReadAll(resp.Body) }
s.Nil(err) got := resp.StatusCode
log.Printf("%s\n\n\n", body) expected := http.StatusNoContent
s.Equal(bytes.Compare(body, []byte(expectedHTML)), 0) if got != expected {
t.Fatalf("invalid status code received for /report endpoint. expected %d got %d", expected, got)
}
} }
func (s *ControllersSuite) clickLink404(rid string) { func reportEmail404(t *testing.T, ctx *testContext, rid string) {
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) if err != nil {
defer resp.Body.Close() t.Fatalf("error requesting /report endpoint: %v", err)
s.Nil(err) }
s.Equal(resp.StatusCode, http.StatusNotFound) got := resp.StatusCode
expected := http.StatusNotFound
if got != expected {
t.Fatalf("invalid status code received for /report endpoint. expected %d got %d", expected, got)
}
} }
func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) { func clickLink(t *testing.T, ctx *testContext, rid string, expectedHTML string) {
resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) if err != nil {
t.Fatalf("error requesting / endpoint: %v", err)
}
defer resp.Body.Close() defer resp.Body.Close()
s.Equal(resp.StatusCode, http.StatusOK) got, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading payload from / endpoint response: %v", err)
}
if !bytes.Equal(got, []byte(expectedHTML)) {
t.Fatalf("invalid response received from / endpoint. expected %s got %s", got, expectedHTML)
}
}
func clickLink404(t *testing.T, ctx *testContext, rid string) {
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid))
if err != nil {
t.Fatalf("error requesting / endpoint: %v", err)
}
defer resp.Body.Close()
got := resp.StatusCode
expected := http.StatusNotFound
if got != expected {
t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got)
}
}
func transparencyRequest(t *testing.T, ctx *testContext, r models.Result, rid, path string) {
resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ctx.phishServer.URL, path, models.RecipientParameter, rid))
if err != nil {
t.Fatalf("error requesting %s endpoint: %v", path, err)
}
defer resp.Body.Close()
got := resp.StatusCode
expected := http.StatusOK
if got != expected {
t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got)
}
tr := &TransparencyResponse{} tr := &TransparencyResponse{}
err = json.NewDecoder(resp.Body).Decode(tr) err = json.NewDecoder(resp.Body).Decode(tr)
s.Nil(err) if err != nil {
s.Equal(tr.ContactAddress, s.config.ContactAddress) t.Fatalf("error unmarshaling transparency request: %v", err)
s.Equal(tr.SendDate, r.SendDate) }
s.Equal(tr.Server, config.ServerName) expectedResponse := &TransparencyResponse{
ContactAddress: ctx.config.ContactAddress,
SendDate: r.SendDate,
Server: config.ServerName,
}
if !reflect.DeepEqual(tr, expectedResponse) {
t.Fatalf("unexpected transparency response received. expected %v got %v", expectedResponse, tr)
}
} }
func (s *ControllersSuite) TestOpenedPhishingEmail() { func TestOpenedPhishingEmail(t *testing.T) {
campaign := s.getFirstCampaign() ctx := setupTest(t)
defer tearDown(t, ctx)
campaign := getFirstCampaign(t)
result := campaign.Results[0] result := campaign.Results[0]
s.Equal(result.Status, models.StatusSending) if result.Status != models.StatusSending {
t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status)
}
s.openEmail(result.RId) openEmail(t, ctx, result.RId)
campaign = s.getFirstCampaign() campaign = getFirstCampaign(t)
result = campaign.Results[0] result = campaign.Results[0]
lastEvent := campaign.Events[len(campaign.Events)-1] lastEvent := campaign.Events[len(campaign.Events)-1]
s.Equal(result.Status, models.EventOpened) if result.Status != models.EventOpened {
s.Equal(lastEvent.Message, models.EventOpened) t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status)
s.Equal(result.ModifiedDate, lastEvent.Time) }
if lastEvent.Message != models.EventOpened {
t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventOpened)
}
if result.ModifiedDate != lastEvent.Time {
t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate)
}
} }
func (s *ControllersSuite) TestReportedPhishingEmail() { func TestReportedPhishingEmail(t *testing.T) {
campaign := s.getFirstCampaign() ctx := setupTest(t)
defer tearDown(t, ctx)
campaign := getFirstCampaign(t)
result := campaign.Results[0] result := campaign.Results[0]
s.Equal(result.Status, models.StatusSending) if result.Status != models.StatusSending {
t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status)
}
s.reportedEmail(result.RId) reportedEmail(t, ctx, result.RId)
campaign = s.getFirstCampaign() campaign = getFirstCampaign(t)
result = campaign.Results[0] result = campaign.Results[0]
lastEvent := campaign.Events[len(campaign.Events)-1] lastEvent := campaign.Events[len(campaign.Events)-1]
s.Equal(result.Reported, true)
s.Equal(lastEvent.Message, models.EventReported) if result.Reported != true {
s.Equal(result.ModifiedDate, lastEvent.Time) t.Fatalf("unexpected result report status received. expected %v got %v", true, result.Reported)
}
if lastEvent.Message != models.EventReported {
t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventReported)
}
if result.ModifiedDate != lastEvent.Time {
t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate)
}
} }
func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() { func TestClickedPhishingLinkAfterOpen(t *testing.T) {
campaign := s.getFirstCampaign() ctx := setupTest(t)
defer tearDown(t, ctx)
campaign := getFirstCampaign(t)
result := campaign.Results[0] result := campaign.Results[0]
s.Equal(result.Status, models.StatusSending) if result.Status != models.StatusSending {
t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status)
}
s.openEmail(result.RId) openEmail(t, ctx, result.RId)
s.clickLink(result.RId, campaign.Page.HTML) clickLink(t, ctx, result.RId, campaign.Page.HTML)
campaign = s.getFirstCampaign() campaign = getFirstCampaign(t)
result = campaign.Results[0] result = campaign.Results[0]
lastEvent := campaign.Events[len(campaign.Events)-1] lastEvent := campaign.Events[len(campaign.Events)-1]
s.Equal(result.Status, models.EventClicked) if result.Status != models.EventClicked {
s.Equal(lastEvent.Message, models.EventClicked) t.Fatalf("unexpected result status received. expected %s got %s", models.EventClicked, result.Status)
s.Equal(result.ModifiedDate, lastEvent.Time) }
if lastEvent.Message != models.EventClicked {
t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventClicked)
}
if result.ModifiedDate != lastEvent.Time {
t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate)
}
} }
func (s *ControllersSuite) TestNoRecipientID() { func TestNoRecipientID(t *testing.T) {
resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL)) ctx := setupTest(t)
s.Nil(err) defer tearDown(t, ctx)
s.Equal(resp.StatusCode, http.StatusNotFound) resp, err := http.Get(fmt.Sprintf("%s/track", ctx.phishServer.URL))
if err != nil {
resp, err = http.Get(s.phishServer.URL) t.Fatalf("error requesting /track endpoint: %v", err)
s.Nil(err) }
s.Equal(resp.StatusCode, http.StatusNotFound) got := resp.StatusCode
expected := http.StatusNotFound
if got != expected {
t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expected, got)
} }
func (s *ControllersSuite) TestInvalidRecipientID() { resp, err = http.Get(ctx.phishServer.URL)
if err != nil {
t.Fatalf("error requesting /track endpoint: %v", err)
}
got = resp.StatusCode
if got != expected {
t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got)
}
}
func TestInvalidRecipientID(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
rid := "XXXXXXXXXX" rid := "XXXXXXXXXX"
s.openEmail404(rid) openEmail404(t, ctx, rid)
s.clickLink404(rid) clickLink404(t, ctx, rid)
s.reportEmail404(rid) reportEmail404(t, ctx, rid)
} }
func (s *ControllersSuite) TestCompletedCampaignClick() { func TestCompletedCampaignClick(t *testing.T) {
campaign := s.getFirstCampaign() ctx := setupTest(t)
defer tearDown(t, ctx)
campaign := getFirstCampaign(t)
result := campaign.Results[0] result := campaign.Results[0]
s.Equal(result.Status, models.StatusSending) if result.Status != models.StatusSending {
s.openEmail(result.RId) t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status)
}
campaign = s.getFirstCampaign() openEmail(t, ctx, result.RId)
campaign = getFirstCampaign(t)
result = campaign.Results[0] result = campaign.Results[0]
s.Equal(result.Status, models.EventOpened) if result.Status != models.EventOpened {
t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status)
}
models.CompleteCampaign(campaign.Id, 1) models.CompleteCampaign(campaign.Id, 1)
s.openEmail404(result.RId) openEmail404(t, ctx, result.RId)
s.clickLink404(result.RId) clickLink404(t, ctx, result.RId)
campaign = s.getFirstCampaign() campaign = getFirstCampaign(t)
result = campaign.Results[0] result = campaign.Results[0]
s.Equal(result.Status, models.EventOpened) if result.Status != models.EventOpened {
t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status)
}
} }
func (s *ControllersSuite) TestRobotsHandler() { func TestRobotsHandler(t *testing.T) {
expected := []byte("User-agent: *\nDisallow: /\n") ctx := setupTest(t)
resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL)) defer tearDown(t, ctx)
s.Nil(err) resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ctx.phishServer.URL))
s.Equal(resp.StatusCode, http.StatusOK) if err != nil {
t.Fatalf("error requesting /robots.txt endpoint: %v", err)
}
defer resp.Body.Close() defer resp.Body.Close()
got := resp.StatusCode
expectedStatus := http.StatusOK
if got != expectedStatus {
t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expectedStatus, got)
}
expected := []byte("User-agent: *\nDisallow: /\n")
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
s.Nil(err) if err != nil {
s.Equal(bytes.Compare(body, expected), 0) t.Fatalf("error reading response body from /robots.txt endpoint: %v", err)
}
if !bytes.Equal(body, expected) {
t.Fatalf("invalid robots.txt response received. expected %s got %s", expected, body)
}
} }
func (s *ControllersSuite) TestInvalidPreviewID() { func TestInvalidPreviewID(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
bogusRId := fmt.Sprintf("%sbogus", models.PreviewPrefix) bogusRId := fmt.Sprintf("%sbogus", models.PreviewPrefix)
s.openEmail404(bogusRId) openEmail404(t, ctx, bogusRId)
s.clickLink404(bogusRId) clickLink404(t, ctx, bogusRId)
s.reportEmail404(bogusRId) reportEmail404(t, ctx, bogusRId)
} }
func (s *ControllersSuite) TestPreviewTrack() { func TestPreviewTrack(t *testing.T) {
req := s.getFirstEmailRequest() ctx := setupTest(t)
s.openEmail(req.RId) defer tearDown(t, ctx)
req := getFirstEmailRequest(t)
openEmail(t, ctx, req.RId)
} }
func (s *ControllersSuite) TestPreviewClick() { func TestPreviewClick(t *testing.T) {
req := s.getFirstEmailRequest() ctx := setupTest(t)
s.clickLink(req.RId, req.Page.HTML) defer tearDown(t, ctx)
req := getFirstEmailRequest(t)
clickLink(t, ctx, req.RId, req.Page.HTML)
} }
func (s *ControllersSuite) TestInvalidTransparencyRequest() { func TestInvalidTransparencyRequest(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
bogusRId := fmt.Sprintf("bogus%s", TransparencySuffix) bogusRId := fmt.Sprintf("bogus%s", TransparencySuffix)
s.openEmail404(bogusRId) openEmail404(t, ctx, bogusRId)
s.clickLink404(bogusRId) clickLink404(t, ctx, bogusRId)
s.reportEmail404(bogusRId) reportEmail404(t, ctx, bogusRId)
} }
func (s *ControllersSuite) TestTransparencyRequest() { func TestTransparencyRequest(t *testing.T) {
campaign := s.getFirstCampaign() ctx := setupTest(t)
defer tearDown(t, ctx)
campaign := getFirstCampaign(t)
result := campaign.Results[0] result := campaign.Results[0]
rid := fmt.Sprintf("%s%s", result.RId, TransparencySuffix) rid := fmt.Sprintf("%s%s", result.RId, TransparencySuffix)
s.transparencyRequest(result, rid, "/") transparencyRequest(t, ctx, result, rid, "/")
s.transparencyRequest(result, rid, "/track") transparencyRequest(t, ctx, result, rid, "/track")
s.transparencyRequest(result, rid, "/report") transparencyRequest(t, ctx, result, rid, "/report")
// And check with the URL encoded version of a + // And check with the URL encoded version of a +
rid = fmt.Sprintf("%s%s", result.RId, "%2b") rid = fmt.Sprintf("%s%s", result.RId, "%2b")
s.transparencyRequest(result, rid, "/") transparencyRequest(t, ctx, result, rid, "/")
s.transparencyRequest(result, rid, "/track") transparencyRequest(t, ctx, result, rid, "/track")
s.transparencyRequest(result, rid, "/report") transparencyRequest(t, ctx, result, rid, "/report")
} }
func (s *ControllersSuite) TestRedirectTemplating() { func TestRedirectTemplating(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
p := models.Page{ p := models.Page{
Name: "Redirect Page", Name: "Redirect Page",
HTML: "<html>Test</html>", HTML: "<html>Test</html>",
@ -239,7 +371,9 @@ func (s *ControllersSuite) TestRedirectTemplating() {
RedirectURL: "http://example.com/{{.RId}}", RedirectURL: "http://example.com/{{.RId}}",
} }
err := models.PostPage(&p) err := models.PostPage(&p)
s.Nil(err) if err != nil {
t.Fatalf("error posting new page: %v", err)
}
smtp, _ := models.GetSMTP(1, 1) smtp, _ := models.GetSMTP(1, 1)
template, _ := models.GetTemplate(1, 1) template, _ := models.GetTemplate(1, 1)
group, _ := models.GetGroup(1, 1) group, _ := models.GetGroup(1, 1)
@ -251,7 +385,9 @@ func (s *ControllersSuite) TestRedirectTemplating() {
campaign.SMTP = smtp campaign.SMTP = smtp
campaign.Groups = []models.Group{group} campaign.Groups = []models.Group{group}
err = models.PostCampaign(&campaign, campaign.UserId) err = models.PostCampaign(&campaign, campaign.UserId)
s.Nil(err) if err != nil {
t.Fatalf("error creating campaign: %v", err)
}
client := http.Client{ client := http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
@ -259,12 +395,22 @@ func (s *ControllersSuite) TestRedirectTemplating() {
}, },
} }
result := campaign.Results[0] result := campaign.Results[0]
resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
s.Nil(err) if err != nil {
defer resp.Body.Close() t.Fatalf("error requesting / endpoint: %v", err)
s.Equal(http.StatusFound, resp.StatusCode) }
expectedURL := fmt.Sprintf("http://example.com/%s", result.RId) defer resp.Body.Close()
got, err := resp.Location() got := resp.StatusCode
s.Nil(err) expectedStatus := http.StatusFound
s.Equal(expectedURL, got.String()) if got != expectedStatus {
t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expectedStatus, got)
}
expectedURL := fmt.Sprintf("http://example.com/%s", result.RId)
gotURL, err := resp.Location()
if err != nil {
t.Fatalf("error getting Location header from response: %v", err)
}
if gotURL.String() != expectedURL {
t.Fatalf("invalid redirect received. expected %s got %s", expectedURL, gotURL)
}
} }

View File

@ -5,106 +5,115 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"testing"
"github.com/PuerkitoBio/goquery" "github.com/PuerkitoBio/goquery"
) )
func (s *ControllersSuite) TestLoginCSRF() { func attemptLogin(t *testing.T, ctx *testContext, client *http.Client, username, password, optionalPath string) *http.Response {
resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL), resp, err := http.Get(fmt.Sprintf("%s/login", ctx.adminServer.URL))
if err != nil {
t.Fatalf("error requesting the /login endpoint: %v", err)
}
got := resp.StatusCode
expected := http.StatusOK
if got != expected {
t.Fatalf("invalid status code received. expected %d got %d", expected, got)
}
doc, err := goquery.NewDocumentFromResponse(resp)
if err != nil {
t.Fatalf("error parsing /login response body")
}
elem := doc.Find("input[name='csrf_token']").First()
token, ok := elem.Attr("value")
if !ok {
t.Fatal("unable to find csrf_token value in login response")
}
if client == nil {
client = &http.Client{}
}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login%s", ctx.adminServer.URL, optionalPath), strings.NewReader(url.Values{
"username": {username},
"password": {password},
"csrf_token": {token},
}.Encode()))
if err != nil {
t.Fatalf("error creating new /login request: %v", err)
}
req.Header.Set("Cookie", resp.Header.Get("Set-Cookie"))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err = client.Do(req)
if err != nil {
t.Fatalf("error requesting the /login endpoint: %v", err)
}
return resp
}
func TestLoginCSRF(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
resp, err := http.PostForm(fmt.Sprintf("%s/login", ctx.adminServer.URL),
url.Values{ url.Values{
"username": {"admin"}, "username": {"admin"},
"password": {"gophish"}, "password": {"gophish"},
}) })
s.Equal(resp.StatusCode, http.StatusForbidden) if err != nil {
fmt.Println(err) t.Fatalf("error requesting the /login endpoint: %v", err)
} }
func (s *ControllersSuite) TestInvalidCredentials() { got := resp.StatusCode
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) expected := http.StatusForbidden
s.Equal(err, nil) if got != expected {
s.Equal(resp.StatusCode, http.StatusOK) t.Fatalf("invalid status code received. expected %d got %d", expected, got)
}
doc, err := goquery.NewDocumentFromResponse(resp)
s.Equal(err, nil)
elem := doc.Find("input[name='csrf_token']").First()
token, ok := elem.Attr("value")
s.Equal(ok, true)
client := &http.Client{}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"invalid"},
"csrf_token": {token},
}.Encode()))
s.Equal(err, nil)
req.Header.Set("Cookie", resp.Header.Get("Set-Cookie"))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err = client.Do(req)
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusUnauthorized)
} }
func (s *ControllersSuite) TestSuccessfulLogin() { func TestInvalidCredentials(t *testing.T) {
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) ctx := setupTest(t)
s.Equal(err, nil) defer tearDown(t, ctx)
s.Equal(resp.StatusCode, http.StatusOK) resp := attemptLogin(t, ctx, nil, "admin", "bogus", "")
got := resp.StatusCode
doc, err := goquery.NewDocumentFromResponse(resp) expected := http.StatusUnauthorized
s.Equal(err, nil) if got != expected {
elem := doc.Find("input[name='csrf_token']").First() t.Fatalf("invalid status code received. expected %d got %d", expected, got)
token, ok := elem.Attr("value") }
s.Equal(ok, true)
client := &http.Client{}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"gophish"},
"csrf_token": {token},
}.Encode()))
s.Equal(err, nil)
req.Header.Set("Cookie", resp.Header.Get("Set-Cookie"))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err = client.Do(req)
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
} }
func (s *ControllersSuite) TestSuccessfulRedirect() { func TestSuccessfulLogin(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
resp := attemptLogin(t, ctx, nil, "admin", "gophish", "")
got := resp.StatusCode
expected := http.StatusOK
if got != expected {
t.Fatalf("invalid status code received. expected %d got %d", expected, got)
}
}
func TestSuccessfulRedirect(t *testing.T) {
ctx := setupTest(t)
defer tearDown(t, ctx)
next := "/campaigns" next := "/campaigns"
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
doc, err := goquery.NewDocumentFromResponse(resp)
s.Equal(err, nil)
elem := doc.Find("input[name='csrf_token']").First()
token, ok := elem.Attr("value")
s.Equal(ok, true)
client := &http.Client{ client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, }}
resp := attemptLogin(t, ctx, client, "admin", "gophish", fmt.Sprintf("?next=%s", next))
got := resp.StatusCode
expected := http.StatusFound
if got != expected {
t.Fatalf("invalid status code received. expected %d got %d", expected, got)
} }
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"gophish"},
"csrf_token": {token},
}.Encode()))
s.Equal(err, nil)
req.Header.Set("Cookie", resp.Header.Get("Set-Cookie"))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
resp, err = client.Do(req)
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusFound)
url, err := resp.Location() url, err := resp.Location()
s.Equal(err, nil) if err != nil {
s.Equal(url.Path, next) t.Fatalf("error parsing response Location header: %v", err)
}
if url.Path != next {
t.Fatalf("unexpected Location header received. expected %s got %s", next, url.Path)
}
} }

View File

@ -8,14 +8,8 @@ import (
"net/textproto" "net/textproto"
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/suite"
) )
type MailerSuite struct {
suite.Suite
}
func generateMessages(dialer Dialer) []Mail { func generateMessages(dialer Dialer) []Mail {
to := []string{"to@example.com"} to := []string{"to@example.com"}
@ -47,30 +41,30 @@ func newMockErrorSender(err error) *mockSender {
return sender return sender
} }
func (ms *MailerSuite) TestDialHost() { func TestDialHost(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
md := newMockDialer() md := newMockDialer()
md.setDial(md.unreachableDial) md.setDial(md.unreachableDial)
_, err := dialHost(ctx, md) _, err := dialHost(ctx, md)
if _, ok := err.(*ErrMaxConnectAttempts); !ok { if _, ok := err.(*ErrMaxConnectAttempts); !ok {
ms.T().Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err) t.Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err)
} }
e := err.(*ErrMaxConnectAttempts) e := err.(*ErrMaxConnectAttempts)
if e.underlyingError != errHostUnreachable { if e.underlyingError != errHostUnreachable {
ms.T().Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable) t.Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable)
} }
if md.dialCount != MaxReconnectAttempts { if md.dialCount != MaxReconnectAttempts {
ms.T().Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount) t.Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount)
} }
md.setDial(md.defaultDial) md.setDial(md.defaultDial)
_, err = dialHost(ctx, md) _, err = dialHost(ctx, md)
if err != nil { if err != nil {
ms.T().Fatalf("Unexpected error when dialing the mock host: %s", err) t.Fatalf("Unexpected error when dialing the mock host: %s", err)
} }
} }
func (ms *MailerSuite) TestMailWorkerStart() { func TestMailWorkerStart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -97,16 +91,16 @@ func (ms *MailerSuite) TestMailWorkerStart() {
got = append(got, message) got = append(got, message)
original := messages[idx].(*mockMessage) original := messages[idx].(*mockMessage)
if original.from != message.from { if original.from != message.from {
ms.T().Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from) t.Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from)
} }
idx++ idx++
} }
if len(got) != len(messages) { if len(got) != len(messages) {
ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages)) t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages))
} }
} }
func (ms *MailerSuite) TestBackoff() { func TestBackoff(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -139,28 +133,28 @@ func (ms *MailerSuite) TestBackoff() {
// Check that we only sent one message // Check that we only sent one message
expectedCount := 1 expectedCount := 1
if len(got) != expectedCount { if len(got) != expectedCount {
ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
} }
// Check that it's the correct message // Check that it's the correct message
originalFrom := messages[1].(*mockMessage).from originalFrom := messages[1].(*mockMessage).from
if got[0].from != originalFrom { if got[0].from != originalFrom {
ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
} }
// Check that the first message performed a backoff // Check that the first message performed a backoff
backoffCount := messages[0].(*mockMessage).backoffCount backoffCount := messages[0].(*mockMessage).backoffCount
if backoffCount != expectedCount { if backoffCount != expectedCount {
ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
} }
// Check that there was a reset performed on the sender // Check that there was a reset performed on the sender
if sender.resetCount != expectedCount { if sender.resetCount != expectedCount {
ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
} }
} }
func (ms *MailerSuite) TestPermError() { func TestPermError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -193,13 +187,13 @@ func (ms *MailerSuite) TestPermError() {
// Check that we only sent one message // Check that we only sent one message
expectedCount := 1 expectedCount := 1
if len(got) != expectedCount { if len(got) != expectedCount {
ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
} }
// Check that it's the correct message // Check that it's the correct message
originalFrom := messages[1].(*mockMessage).from originalFrom := messages[1].(*mockMessage).from
if got[0].from != originalFrom { if got[0].from != originalFrom {
ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
} }
message := messages[0].(*mockMessage) message := messages[0].(*mockMessage)
@ -208,21 +202,21 @@ func (ms *MailerSuite) TestPermError() {
expectedBackoffCount := 0 expectedBackoffCount := 0
backoffCount := message.backoffCount backoffCount := message.backoffCount
if backoffCount != expectedBackoffCount { if backoffCount != expectedBackoffCount {
ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
} }
// Check that there was a reset performed on the sender // Check that there was a reset performed on the sender
if sender.resetCount != expectedCount { if sender.resetCount != expectedCount {
ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
} }
// Check that the email errored out appropriately // Check that the email errored out appropriately
if !reflect.DeepEqual(message.err, expectedError) { if !reflect.DeepEqual(message.err, expectedError) {
ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
} }
} }
func (ms *MailerSuite) TestUnknownError() { func TestUnknownError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -252,13 +246,13 @@ func (ms *MailerSuite) TestUnknownError() {
// Check that we only sent one message // Check that we only sent one message
expectedCount := 1 expectedCount := 1
if len(got) != expectedCount { if len(got) != expectedCount {
ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
} }
// Check that it's the correct message // Check that it's the correct message
originalFrom := messages[1].(*mockMessage).from originalFrom := messages[1].(*mockMessage).from
if got[0].from != originalFrom { if got[0].from != originalFrom {
ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
} }
message := messages[0].(*mockMessage) message := messages[0].(*mockMessage)
@ -271,21 +265,17 @@ func (ms *MailerSuite) TestUnknownError() {
expectedBackoffCount := 1 expectedBackoffCount := 1
backoffCount := message.backoffCount backoffCount := message.backoffCount
if backoffCount != expectedBackoffCount { if backoffCount != expectedBackoffCount {
ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount) t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount)
} }
// Check that the underlying connection was reestablished // Check that the underlying connection was reestablished
expectedDialCount := 2 expectedDialCount := 2
if dialer.dialCount != expectedDialCount { if dialer.dialCount != expectedDialCount {
ms.T().Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount) t.Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount)
} }
// Check that the email errored out appropriately // Check that the email errored out appropriately
if !reflect.DeepEqual(message.err, expectedError) { if !reflect.DeepEqual(message.err, expectedError) {
ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
} }
} }
func TestMailerSuite(t *testing.T) {
suite.Run(t, new(MailerSuite))
}

View File

@ -9,19 +9,17 @@ import (
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context" ctx "github.com/gophish/gophish/context"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/stretchr/testify/suite"
) )
var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("success")) w.Write([]byte("success"))
}) })
type MiddlewareSuite struct { type testContext struct {
suite.Suite
apiKey string apiKey string
} }
func (s *MiddlewareSuite) SetupSuite() { func setupTest(t *testing.T) *testContext {
conf := &config.Config{ conf := &config.Config{
DBName: "sqlite3", DBName: "sqlite3",
DBPath: ":memory:", DBPath: ":memory:",
@ -29,12 +27,16 @@ func (s *MiddlewareSuite) SetupSuite() {
} }
err := models.Setup(conf) err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) t.Fatalf("Failed creating database: %v", err)
} }
// Get the API key to use for these tests // Get the API key to use for these tests
u, err := models.GetUser(1) u, err := models.GetUser(1)
s.Nil(err) if err != nil {
s.apiKey = u.ApiKey t.Fatalf("error getting user: %v", err)
}
ctx := &testContext{}
ctx.apiKey = u.ApiKey
return ctx
} }
// MiddlewarePermissionTest maps an expected HTTP Method to an expected HTTP // MiddlewarePermissionTest maps an expected HTTP Method to an expected HTTP
@ -43,7 +45,8 @@ type MiddlewarePermissionTest map[string]int
// TestEnforceViewOnly ensures that only users with the ModifyObjects // TestEnforceViewOnly ensures that only users with the ModifyObjects
// permission have the ability to send non-GET requests. // permission have the ability to send non-GET requests.
func (s *MiddlewareSuite) TestEnforceViewOnly() { func TestEnforceViewOnly(t *testing.T) {
setupTest(t)
permissionTests := map[string]MiddlewarePermissionTest{ permissionTests := map[string]MiddlewarePermissionTest{
models.RoleAdmin: MiddlewarePermissionTest{ models.RoleAdmin: MiddlewarePermissionTest{
http.MethodGet: http.StatusOK, http.MethodGet: http.StatusOK,
@ -64,7 +67,9 @@ func (s *MiddlewareSuite) TestEnforceViewOnly() {
} }
for r, checks := range permissionTests { for r, checks := range permissionTests {
role, err := models.GetRoleBySlug(r) role, err := models.GetRoleBySlug(r)
s.Nil(err) if err != nil {
t.Fatalf("error getting role by slug: %v", err)
}
for method, expected := range checks { for method, expected := range checks {
req := httptest.NewRequest(method, "/", nil) req := httptest.NewRequest(method, "/", nil)
@ -76,12 +81,16 @@ func (s *MiddlewareSuite) TestEnforceViewOnly() {
}) })
EnforceViewOnly(successHandler).ServeHTTP(response, req) EnforceViewOnly(successHandler).ServeHTTP(response, req)
s.Equal(response.Code, expected) got := response.Code
if got != expected {
t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
}
} }
} }
} }
func (s *MiddlewareSuite) TestRequirePermission() { func TestRequirePermission(t *testing.T) {
setupTest(t)
middleware := RequirePermission(models.PermissionModifySystem) middleware := RequirePermission(models.PermissionModifySystem)
handler := middleware(successHandler) handler := middleware(successHandler)
@ -95,26 +104,37 @@ func (s *MiddlewareSuite) TestRequirePermission() {
response := httptest.NewRecorder() response := httptest.NewRecorder()
// Test that with the requested permission, the request succeeds // Test that with the requested permission, the request succeeds
role, err := models.GetRoleBySlug(role) role, err := models.GetRoleBySlug(role)
s.Nil(err) if err != nil {
t.Fatalf("error getting role by slug: %v", err)
}
req = ctx.Set(req, "user", models.User{ req = ctx.Set(req, "user", models.User{
Role: role, Role: role,
RoleID: role.ID, RoleID: role.ID,
}) })
handler.ServeHTTP(response, req) handler.ServeHTTP(response, req)
s.Equal(response.Code, expected) got := response.Code
if got != expected {
t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
}
} }
} }
func (s *MiddlewareSuite) TestRequireAPIKey() { func TestRequireAPIKey(t *testing.T) {
setupTest(t)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
response := httptest.NewRecorder() response := httptest.NewRecorder()
// Test that making a request without an API key is denied // Test that making a request without an API key is denied
RequireAPIKey(successHandler).ServeHTTP(response, req) RequireAPIKey(successHandler).ServeHTTP(response, req)
s.Equal(response.Code, http.StatusUnauthorized) expected := http.StatusUnauthorized
got := response.Code
if got != expected {
t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
}
} }
func (s *MiddlewareSuite) TestInvalidAPIKey() { func TestInvalidAPIKey(t *testing.T) {
setupTest(t)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
query := req.URL.Query() query := req.URL.Query()
query.Set("api_key", "bogus-api-key") query.Set("api_key", "bogus-api-key")
@ -122,18 +142,23 @@ func (s *MiddlewareSuite) TestInvalidAPIKey() {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
response := httptest.NewRecorder() response := httptest.NewRecorder()
RequireAPIKey(successHandler).ServeHTTP(response, req) RequireAPIKey(successHandler).ServeHTTP(response, req)
s.Equal(response.Code, http.StatusUnauthorized) expected := http.StatusUnauthorized
got := response.Code
if got != expected {
t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
}
} }
func (s *MiddlewareSuite) TestBearerToken() { func TestBearerToken(t *testing.T) {
testCtx := setupTest(t)
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.apiKey)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", testCtx.apiKey))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
response := httptest.NewRecorder() response := httptest.NewRecorder()
RequireAPIKey(successHandler).ServeHTTP(response, req) RequireAPIKey(successHandler).ServeHTTP(response, req)
s.Equal(response.Code, http.StatusOK) expected := http.StatusOK
got := response.Code
if got != expected {
t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
} }
func TestMiddlewareSuite(t *testing.T) {
suite.Run(t, new(MiddlewareSuite))
} }

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"path/filepath"
"time" "time"
"bitbucket.org/liamstask/goose/lib/goose" "bitbucket.org/liamstask/goose/lib/goose"
@ -93,6 +94,8 @@ func Setup(c *config.Config) error {
Env: "production", Env: "production",
Driver: chooseDBDriver(conf.DBName, conf.DBPath), Driver: chooseDBDriver(conf.DBName, conf.DBPath),
} }
abs, _ := filepath.Abs(migrateConf.MigrationsDir)
fmt.Println(abs)
// Get the latest possible migration // Get the latest possible migration
latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir) latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir)
if err != nil { if err != nil {

View File

@ -8,28 +8,9 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/stretchr/testify/suite"
) )
type UtilSuite struct {
suite.Suite
}
func (s *UtilSuite) SetupSuite() {
conf := &config.Config{
DBName: "sqlite3",
DBPath: ":memory:",
MigrationsPath: "../db/db_sqlite3/migrations/",
}
err := models.Setup(conf)
if err != nil {
s.T().Fatalf("Failed creating database: %v", err)
}
s.Nil(err)
}
func buildCSVRequest(csvPayload string) (*http.Request, error) { func buildCSVRequest(csvPayload string) (*http.Request, error) {
csvHeader := "First Name,Last Name,Email\n" csvHeader := "First Name,Last Name,Email\n"
body := new(bytes.Buffer) body := new(bytes.Buffer)
@ -52,7 +33,7 @@ func buildCSVRequest(csvPayload string) (*http.Request, error) {
return r, nil return r, nil
} }
func (s *UtilSuite) TestParseCSVEmail() { func TestParseCSVEmail(t *testing.T) {
expected := models.Target{ expected := models.Target{
BaseRecipient: models.BaseRecipient{ BaseRecipient: models.BaseRecipient{
FirstName: "John", FirstName: "John",
@ -63,16 +44,19 @@ func (s *UtilSuite) TestParseCSVEmail() {
csvPayload := fmt.Sprintf("%s,%s,<%s>", expected.FirstName, expected.LastName, expected.Email) csvPayload := fmt.Sprintf("%s,%s,<%s>", expected.FirstName, expected.LastName, expected.Email)
r, err := buildCSVRequest(csvPayload) r, err := buildCSVRequest(csvPayload)
s.Nil(err) if err != nil {
t.Fatalf("error building CSV request: %v", err)
}
got, err := ParseCSV(r) got, err := ParseCSV(r)
s.Nil(err) if err != nil {
s.Equal(len(got), 1) t.Fatalf("error parsing CSV: %v", err)
}
expectedLength := 1
if len(got) != expectedLength {
t.Fatalf("invalid number of results received from CSV. expected %d got %d", expectedLength, len(got))
}
if !reflect.DeepEqual(expected, got[0]) { if !reflect.DeepEqual(expected, got[0]) {
s.T().Fatalf("Incorrect targets received. Expected: %#v\nGot: %#v", expected, got) t.Fatalf("Incorrect targets received. Expected: %#v\nGot: %#v", expected, got)
} }
} }
func TestUtilSuite(t *testing.T) {
suite.Run(t, new(UtilSuite))
}

View File

@ -7,16 +7,10 @@ import (
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
) )
type WebhookSuite struct {
suite.Suite
}
type mockSender struct { type mockSender struct {
client *http.Client client *http.Client
} }
@ -33,22 +27,24 @@ func (ms mockSender) Send(endPoint EndPoint, data interface{}) error {
return nil return nil
} }
func (s *WebhookSuite) TestSendMocked() { func TestSendMocked(t *testing.T) {
mcSnd := newMockSender() ms := newMockSender()
endp1 := EndPoint{URL: "http://example.com/a1", Secret: "s1"} endpoint := EndPoint{URL: "http://example.com/a1", Secret: "s1"}
d1 := map[string]string{ data := map[string]string{
"a1": "a11", "a1": "a11",
"a2": "a22", "a2": "a22",
"a3": "a33", "a3": "a33",
} }
err := mcSnd.Send(endp1, d1) err := ms.Send(endpoint, data)
s.Nil(err) if err != nil {
t.Fatalf("error sending data to webhook endpoint: %v", err)
}
} }
func (s *WebhookSuite) TestSendReal() { func TestSendReal(t *testing.T) {
expectedSign := "004b36ca3fcbc01a08b17bf5d4a7e1aa0b10e14f55f3f8bd9acac0c7e8d2635d" expectedSig := "004b36ca3fcbc01a08b17bf5d4a7e1aa0b10e14f55f3f8bd9acac0c7e8d2635d"
secret := "secret456" secret := "secret456"
d1 := map[string]interface{}{ data := map[string]interface{}{
"key1": "val1", "key1": "val1",
"key2": "val2", "key2": "val2",
"key3": "val3", "key3": "val3",
@ -58,37 +54,50 @@ func (s *WebhookSuite) TestSendReal() {
fmt.Println("[test] running the server...") fmt.Println("[test] running the server...")
signStartIdx := len(Sha256Prefix) + 1 signStartIdx := len(Sha256Prefix) + 1
realSignRaw := r.Header.Get(SignatureHeader) sigHeader := r.Header.Get(SignatureHeader)
realSign := realSignRaw[signStartIdx:] gotSig := sigHeader[signStartIdx:]
assert.Equal(s.T(), expectedSign, realSign) if expectedSig != gotSig {
t.Fatalf("invalid signature received. expected %s got %s", expectedSig, gotSig)
}
contTypeJsonHeader := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
assert.Equal(s.T(), contTypeJsonHeader, "application/json") expectedCT := "application/json"
if ct != expectedCT {
t.Fatalf("invalid content type. expected %s got %s", ct, expectedCT)
}
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
s.Nil(err) if err != nil {
t.Fatalf("error reading JSON body from webhook request: %v", err)
}
var d2 map[string]interface{} var payload map[string]interface{}
err = json.Unmarshal(body, &d2) err = json.Unmarshal(body, &payload)
s.Nil(err) if err != nil {
assert.Equal(s.T(), d1, d2) t.Fatalf("error unmarshaling webhook payload: %v", err)
}
if !reflect.DeepEqual(data, payload) {
t.Fatalf("invalid payload received. expected %#v got %#v", data, payload)
}
})) }))
defer ts.Close() defer ts.Close()
endp1 := EndPoint{URL: ts.URL, Secret: secret} endp1 := EndPoint{URL: ts.URL, Secret: secret}
err := Send(endp1, d1) err := Send(endp1, data)
s.Nil(err) if err != nil {
t.Fatalf("error sending data to webhook endpoint: %v", err)
}
} }
func (s *WebhookSuite) TestSignature() { func TestSignature(t *testing.T) {
secret := "secret123" secret := "secret123"
payload := []byte("some payload456") payload := []byte("some payload456")
expectedSign := "ab7844c1e9149f8dc976c4188a72163c005930f3c2266a163ffe434230bdf761" expected := "ab7844c1e9149f8dc976c4188a72163c005930f3c2266a163ffe434230bdf761"
realSign, err := sign(secret, payload) got, err := sign(secret, payload)
s.Nil(err) if err != nil {
assert.Equal(s.T(), expectedSign, realSign) t.Fatalf("error signing payload: %v", err)
}
if expected != got {
t.Fatalf("invalid signature received. expected %s got %s", expected, got)
} }
func TestWebhookSuite(t *testing.T) {
suite.Run(t, new(WebhookSuite))
} }

View File

@ -1,18 +1,18 @@
package worker package worker
import ( import (
"testing"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/stretchr/testify/suite"
) )
// WorkerSuite is a suite of tests to cover API related functions // testContext is context to cover API related functions
type WorkerSuite struct { type testContext struct {
suite.Suite
config *config.Config config *config.Config
} }
func (s *WorkerSuite) SetupSuite() { func setupTest(t *testing.T) *testContext {
conf := &config.Config{ conf := &config.Config{
DBName: "sqlite3", DBName: "sqlite3",
DBPath: ":memory:", DBPath: ":memory:",
@ -20,21 +20,15 @@ func (s *WorkerSuite) SetupSuite() {
} }
err := models.Setup(conf) err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) t.Fatalf("Failed creating database: %v", err)
} }
s.config = conf ctx := &testContext{}
s.Nil(err) ctx.config = conf
return ctx
} }
func (s *WorkerSuite) TearDownTest() { func createTestData(t *testing.T, ctx *testContext) {
campaigns, _ := models.GetCampaigns(1) ctx.config.TestFlag = true
for _, campaign := range campaigns {
models.DeleteCampaign(campaign.Id)
}
}
func (s *WorkerSuite) SetupTest() {
s.config.TestFlag = true
// Add a group // Add a group
group := models.Group{Name: "Test Group"} group := models.Group{Name: "Test Group"}
group.Targets = []models.Target{ group.Targets = []models.Target{
@ -45,12 +39,12 @@ func (s *WorkerSuite) SetupTest() {
models.PostGroup(&group) models.PostGroup(&group)
// Add a template // Add a template
t := models.Template{Name: "Test Template"} template := models.Template{Name: "Test Template"}
t.Subject = "Test subject" template.Subject = "Test subject"
t.Text = "Text text" template.Text = "Text text"
t.HTML = "<html>Test</html>" template.HTML = "<html>Test</html>"
t.UserId = 1 template.UserId = 1
models.PostTemplate(&t) models.PostTemplate(&template)
// Add a landing page // Add a landing page
p := models.Page{Name: "Test Page"} p := models.Page{Name: "Test Page"}
@ -69,7 +63,7 @@ func (s *WorkerSuite) SetupTest() {
// Set the status such that no emails are attempted // Set the status such that no emails are attempted
c := models.Campaign{Name: "Test campaign"} c := models.Campaign{Name: "Test campaign"}
c.UserId = 1 c.UserId = 1
c.Template = t c.Template = template
c.Page = p c.Page = p
c.SMTP = smtp c.SMTP = smtp
c.Groups = []models.Group{group} c.Groups = []models.Group{group}