From be459e47bf23077bcc3c3e717527b5660132f50b Mon Sep 17 00:00:00 2001 From: Jordan Wright Date: Sat, 1 Feb 2020 21:44:50 -0600 Subject: [PATCH] Refactoring tests to remove stretchr/testify dependency --- config/config_test.go | 56 +++-- controllers/api/api_test.go | 85 +++---- controllers/api/user_test.go | 179 +++++++++----- controllers/controllers_test.go | 87 +++---- controllers/phish_test.go | 418 +++++++++++++++++++++----------- controllers/route_test.go | 171 ++++++------- mailer/mailer_test.go | 60 ++--- middleware/middleware_test.go | 73 ++++-- models/models.go | 3 + util/util_test.go | 40 +-- webhook/webhook_test.go | 83 ++++--- worker/worker_test.go | 42 ++-- 12 files changed, 765 insertions(+), 532 deletions(-) diff --git a/config/config_test.go b/config/config_test.go index 26f4fd4f..e0a553e8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,16 +4,10 @@ import ( "encoding/json" "io/ioutil" "os" + "reflect" "testing" - - "github.com/stretchr/testify/suite" ) -type ConfigSuite struct { - suite.Suite - ConfigFile *os.File -} - var validConfig = []byte(`{ "admin_server": { "listen_url": "127.0.0.1:3333", @@ -33,36 +27,48 @@ var validConfig = []byte(`{ "contact_address": "" }`) -func (s *ConfigSuite) SetupTest() { +func createTemporaryConfig(t *testing.T) *os.File { f, err := ioutil.TempFile("", "gophish-config") - s.Nil(err) - s.ConfigFile = f + if err != nil { + t.Fatalf("unable to create temporary config: %v", err) + } + return f } -func (s *ConfigSuite) TearDownTest() { - err := s.ConfigFile.Close() - s.Nil(err) +func removeTemporaryConfig(t *testing.T, f *os.File) { + err := f.Close() + if err != nil { + t.Fatalf("unable to remove temporary config: %v", err) + } } -func (s *ConfigSuite) TestLoadConfig() { - _, err := s.ConfigFile.Write(validConfig) - s.Nil(err) +func TestLoadConfig(t *testing.T) { + f := createTemporaryConfig(t) + defer removeTemporaryConfig(t, f) + _, err := f.Write(validConfig) + if err != nil { + t.Fatalf("error writing config to temporary file: %v", err) + } // Load the valid config - conf, err := LoadConfig(s.ConfigFile.Name()) - s.Nil(err) + conf, err := LoadConfig(f.Name()) + if err != nil { + t.Fatalf("error loading config from temporary file: %v", err) + } expectedConfig := &Config{} err = json.Unmarshal(validConfig, &expectedConfig) - s.Nil(err) + if err != nil { + t.Fatalf("error unmarshaling config: %v", err) + } expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName expectedConfig.TestFlag = false - s.Equal(expectedConfig, conf) + if !reflect.DeepEqual(expectedConfig, conf) { + t.Fatalf("invalid config received. expected %#v got %#v", expectedConfig, conf) + } // Load an invalid config conf, err = LoadConfig("bogusfile") - s.NotNil(err) -} - -func TestConfigSuite(t *testing.T) { - suite.Run(t, new(ConfigSuite)) + if err == nil { + t.Fatalf("expected error when loading invalid config, but got %v", err) + } } diff --git a/controllers/api/api_test.go b/controllers/api/api_test.go index 09339799..b5dac5fb 100644 --- a/controllers/api/api_test.go +++ b/controllers/api/api_test.go @@ -6,23 +6,20 @@ import ( "fmt" "net/http" "net/http/httptest" - "os" "testing" "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" - "github.com/stretchr/testify/suite" ) -type APISuite struct { - suite.Suite +type testContext struct { apiKey string config *config.Config apiServer *Server admin models.User } -func (s *APISuite) SetupSuite() { +func setupTest(t *testing.T) *testContext { conf := &config.Config{ DBName: "sqlite3", DBPath: ":memory:", @@ -30,39 +27,34 @@ func (s *APISuite) SetupSuite() { } err := models.Setup(conf) if err != nil { - s.T().Fatalf("Failed creating database: %v", err) + t.Fatalf("Failed creating database: %v", err) } - s.config = conf - s.Nil(err) + ctx := &testContext{} + ctx.config = conf // Get the API key to use for these tests u, err := models.GetUser(1) - s.Nil(err) - s.apiKey = u.ApiKey - s.admin = u - // Move our cwd up to the project root for help with resolving - // static assets - err = os.Chdir("../") - s.Nil(err) - s.apiServer = NewServer() + if err != nil { + t.Fatalf("error getting admin user: %v", err) + } + ctx.apiKey = u.ApiKey + ctx.admin = u + ctx.apiServer = NewServer() + return ctx } -func (s *APISuite) TearDownTest() { - campaigns, _ := models.GetCampaigns(1) - for _, campaign := range campaigns { - models.DeleteCampaign(campaign.Id) - } +func tearDown(t *testing.T, ctx *testContext) { // Cleanup all users except the original admin - users, _ := models.GetUsers() - for _, user := range users { - if user.Id == 1 { - continue - } - err := models.DeleteUser(user.Id) - s.Nil(err) - } + // users, _ := models.GetUsers() + // for _, user := range users { + // if user.Id == 1 { + // continue + // } + // err := models.DeleteUser(user.Id) + // s.Nil(err) + // } } -func (s *APISuite) SetupTest() { +func createTestData(t *testing.T) { // Add a group group := models.Group{Name: "Test Group"} group.Targets = []models.Target{ @@ -73,12 +65,12 @@ func (s *APISuite) SetupTest() { models.PostGroup(&group) // Add a template - t := models.Template{Name: "Test Template"} - t.Subject = "Test subject" - t.Text = "Text text" - t.HTML = "Test" - t.UserId = 1 - models.PostTemplate(&t) + template := models.Template{Name: "Test Template"} + template.Subject = "Test subject" + template.Text = "Text text" + template.HTML = "Test" + template.UserId = 1 + models.PostTemplate(&template) // Add a landing page p := models.Page{Name: "Test Page"} @@ -97,7 +89,7 @@ func (s *APISuite) SetupTest() { // Set the status such that no emails are attempted c := models.Campaign{Name: "Test campaign"} c.UserId = 1 - c.Template = t + c.Template = template c.Page = p c.SMTP = smtp c.Groups = []models.Group{group} @@ -105,12 +97,13 @@ func (s *APISuite) SetupTest() { c.UpdateStatus(models.CampaignEmailsSent) } -func (s *APISuite) TestSiteImportBaseHref() { +func TestSiteImportBaseHref(t *testing.T) { + ctx := setupTest(t) h := "" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, h) })) - hr := fmt.Sprintf("\n", ts.URL) + expected := fmt.Sprintf("\n", ts.URL) defer ts.Close() req := httptest.NewRequest(http.MethodPost, "/api/import/site", bytes.NewBuffer([]byte(fmt.Sprintf(` @@ -121,13 +114,13 @@ func (s *APISuite) TestSiteImportBaseHref() { `, ts.URL)))) req.Header.Set("Content-Type", "application/json") response := httptest.NewRecorder() - s.apiServer.ImportSite(response, req) + ctx.apiServer.ImportSite(response, req) cs := cloneResponse{} err := json.NewDecoder(response.Body).Decode(&cs) - s.Nil(err) - s.Equal(cs.HTML, hr) -} - -func TestAPISuite(t *testing.T) { - suite.Run(t, new(APISuite)) + if err != nil { + t.Fatalf("error decoding response: %v", err) + } + if cs.HTML != expected { + t.Fatalf("unexpected response received. expected %s got %s", expected, cs.HTML) + } } diff --git a/controllers/api/user_test.go b/controllers/api/user_test.go index 65dc0255..fa49a6f1 100644 --- a/controllers/api/user_test.go +++ b/controllers/api/user_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "testing" "golang.org/x/crypto/bcrypt" @@ -13,9 +14,11 @@ import ( "github.com/gophish/gophish/models" ) -func (s *APISuite) createUnpriviledgedUser(slug string) *models.User { +func createUnpriviledgedUser(t *testing.T, slug string) *models.User { role, err := models.GetRoleBySlug(slug) - s.Nil(err) + if err != nil { + t.Fatalf("error getting role by slug: %v", err) + } unauthorizedUser := &models.User{ Username: "foo", Hash: "bar", @@ -24,56 +27,82 @@ func (s *APISuite) createUnpriviledgedUser(slug string) *models.User { RoleID: role.ID, } err = models.PutUser(unauthorizedUser) - s.Nil(err) + if err != nil { + t.Fatalf("error saving unpriviledged user: %v", err) + } return unauthorizedUser } -func (s *APISuite) TestGetUsers() { +func TestGetUsers(t *testing.T) { + testCtx := setupTest(t) r := httptest.NewRequest(http.MethodGet, "/api/users", nil) - r = ctx.Set(r, "user", s.admin) + r = ctx.Set(r, "user", testCtx.admin) w := httptest.NewRecorder() - s.apiServer.Users(w, r) - s.Equal(w.Code, http.StatusOK) + testCtx.apiServer.Users(w, r) + expected := http.StatusOK + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } got := []models.User{} err := json.NewDecoder(w.Body).Decode(&got) - s.Nil(err) + if err != nil { + t.Fatalf("error decoding users data: %v", err) + } // We only expect one user - s.Equal(1, len(got)) + expectedUsers := 1 + if len(got) != expectedUsers { + t.Fatalf("unexpected number of users returned. expected %d got %d", expectedUsers, len(got)) + } // And it should be the admin user - s.Equal(s.admin.Id, got[0].Id) + if testCtx.admin.Id != got[0].Id { + t.Fatalf("unexpected user received. expected %d got %d", testCtx.admin.Id, got[0].Id) + } } -func (s *APISuite) TestCreateUser() { +func TestCreateUser(t *testing.T) { + testCtx := setupTest(t) payload := &userRequest{ Username: "foo", Password: "bar", Role: models.RoleUser, } body, err := json.Marshal(payload) - s.Nil(err) + if err != nil { + t.Fatalf("error marshaling userRequest payload: %v", err) + } r := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBuffer(body)) r.Header.Set("Content-Type", "application/json") - r = ctx.Set(r, "user", s.admin) + r = ctx.Set(r, "user", testCtx.admin) w := httptest.NewRecorder() - s.apiServer.Users(w, r) - s.Equal(w.Code, http.StatusOK) + testCtx.apiServer.Users(w, r) + expected := http.StatusOK + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } got := &models.User{} err = json.NewDecoder(w.Body).Decode(got) - s.Nil(err) - s.Equal(got.Username, payload.Username) - s.Equal(got.Role.Slug, payload.Role) + if err != nil { + t.Fatalf("error decoding user payload: %v", err) + } + if got.Username != payload.Username { + t.Fatalf("unexpected username received. expected %s got %s", payload.Username, got.Username) + } + if got.Role.Slug != payload.Role { + t.Fatalf("unexpected role received. expected %s got %s", payload.Role, got.Role.Slug) + } } // TestModifyUser tests that a user with the appropriate access is able to // modify their username and password. -func (s *APISuite) TestModifyUser() { - unpriviledgedUser := s.createUnpriviledgedUser(models.RoleUser) +func TestModifyUser(t *testing.T) { + testCtx := setupTest(t) + unpriviledgedUser := createUnpriviledgedUser(t, models.RoleUser) newPassword := "new-password" newUsername := "new-username" payload := userRequest{ @@ -82,33 +111,48 @@ func (s *APISuite) TestModifyUser() { Role: unpriviledgedUser.Role.Slug, } body, err := json.Marshal(payload) - s.Nil(err) + if err != nil { + t.Fatalf("error marshaling userRequest payload: %v", err) + } url := fmt.Sprintf("/api/users/%d", unpriviledgedUser.Id) r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unpriviledgedUser.ApiKey)) w := httptest.NewRecorder() - s.apiServer.ServeHTTP(w, r) + testCtx.apiServer.ServeHTTP(w, r) response := &models.User{} err = json.NewDecoder(w.Body).Decode(response) - s.Nil(err) - s.Equal(w.Code, http.StatusOK) - s.Equal(response.Username, newUsername) + if err != nil { + t.Fatalf("error decoding user payload: %v", err) + } + expected := http.StatusOK + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } + if response.Username != newUsername { + t.Fatalf("unexpected username received. expected %s got %s", newUsername, response.Username) + } got, err := models.GetUser(unpriviledgedUser.Id) - s.Nil(err) - s.Equal(response.Username, got.Username) - s.Equal(newUsername, got.Username) + if err != nil { + t.Fatalf("error getting unpriviledged user: %v", err) + } + if response.Username != got.Username { + t.Fatalf("unexpected username received. expected %s got %s", response.Username, got.Username) + } err = bcrypt.CompareHashAndPassword([]byte(got.Hash), []byte(newPassword)) - s.Nil(err) + if err != nil { + t.Fatalf("incorrect hash received for created user. expected %s got %s", []byte(newPassword), []byte(got.Hash)) + } } // TestUnauthorizedListUsers ensures that users without the ModifySystem // permission are unable to list the users registered in Gophish. -func (s *APISuite) TestUnauthorizedListUsers() { +func TestUnauthorizedListUsers(t *testing.T) { + testCtx := setupTest(t) // First, let's create a standard user which doesn't // have ModifySystem permissions. - unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) + unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser) // We'll try to make a request to the various users API endpoints to // ensure that they fail. Previously, we could hit the handlers directly // but we need to go through the router for this test to ensure the @@ -117,72 +161,99 @@ func (s *APISuite) TestUnauthorizedListUsers() { r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) w := httptest.NewRecorder() - s.apiServer.ServeHTTP(w, r) - s.Equal(w.Code, http.StatusForbidden) + testCtx.apiServer.ServeHTTP(w, r) + expected := http.StatusForbidden + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } } // TestUnauthorizedModifyUsers verifies that users without ModifySystem // permission (a "standard" user) can only get or modify their own information. -func (s *APISuite) TestUnauthorizedGetUser() { +func TestUnauthorizedGetUser(t *testing.T) { + testCtx := setupTest(t) // First, we'll make sure that a user with the "user" role is unable to // get the information of another user (in this case, the main admin). - unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) - url := fmt.Sprintf("/api/users/%d", s.admin.Id) + unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser) + url := fmt.Sprintf("/api/users/%d", testCtx.admin.Id) r := httptest.NewRequest(http.MethodGet, url, nil) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) w := httptest.NewRecorder() - s.apiServer.ServeHTTP(w, r) - s.Equal(w.Code, http.StatusForbidden) + testCtx.apiServer.ServeHTTP(w, r) + expected := http.StatusForbidden + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } } // TestUnauthorizedModifyRole ensures that users without the ModifySystem // privilege are unable to modify their own role, preventing a potential // privilege escalation issue. -func (s *APISuite) TestUnauthorizedSetRole() { - unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) +func TestUnauthorizedSetRole(t *testing.T) { + testCtx := setupTest(t) + unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser) url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id) payload := &userRequest{ Username: unauthorizedUser.Username, Role: models.RoleAdmin, } body, err := json.Marshal(payload) - s.Nil(err) + if err != nil { + t.Fatalf("error marshaling userRequest payload: %v", err) + } r := httptest.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) w := httptest.NewRecorder() - s.apiServer.ServeHTTP(w, r) - s.Equal(w.Code, http.StatusBadRequest) + testCtx.apiServer.ServeHTTP(w, r) + expected := http.StatusBadRequest + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } response := &models.Response{} err = json.NewDecoder(w.Body).Decode(response) - s.Nil(err) - s.Equal(response.Message, ErrInsufficientPermission.Error()) + if err != nil { + t.Fatalf("error decoding response payload: %v", err) + } + if response.Message != ErrInsufficientPermission.Error() { + t.Fatalf("incorrect error received when setting role. expected %s got %s", ErrInsufficientPermission.Error(), response.Message) + } } // TestModifyWithExistingUsername verifies that it's not possible to modify // an user's username to one which already exists. -func (s *APISuite) TestModifyWithExistingUsername() { - unauthorizedUser := s.createUnpriviledgedUser(models.RoleUser) +func TestModifyWithExistingUsername(t *testing.T) { + testCtx := setupTest(t) + unauthorizedUser := createUnpriviledgedUser(t, models.RoleUser) payload := &userRequest{ - Username: s.admin.Username, + Username: testCtx.admin.Username, Role: unauthorizedUser.Role.Slug, } body, err := json.Marshal(payload) - s.Nil(err) + if err != nil { + t.Fatalf("error marshaling userRequest payload: %v", err) + } url := fmt.Sprintf("/api/users/%d", unauthorizedUser.Id) r := httptest.NewRequest(http.MethodPut, url, bytes.NewReader(body)) r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", unauthorizedUser.ApiKey)) w := httptest.NewRecorder() - s.apiServer.ServeHTTP(w, r) - s.Equal(w.Code, http.StatusBadRequest) - expected := &models.Response{ + testCtx.apiServer.ServeHTTP(w, r) + expected := http.StatusBadRequest + if w.Code != expected { + t.Fatalf("unexpected error code received. expected %d got %d", expected, w.Code) + } + expectedResponse := &models.Response{ Message: ErrUsernameTaken.Error(), Success: false, } got := &models.Response{} err = json.NewDecoder(w.Body).Decode(got) - s.Nil(err) - s.Equal(got.Message, expected.Message) + if err != nil { + t.Fatalf("error decoding response payload: %v", err) + } + if got.Message != expectedResponse.Message { + t.Fatalf("incorrect error received when setting role. expected %s got %s", expectedResponse.Message, got.Message) + } } diff --git a/controllers/controllers_test.go b/controllers/controllers_test.go index de0a7366..edbe695f 100644 --- a/controllers/controllers_test.go +++ b/controllers/controllers_test.go @@ -1,62 +1,75 @@ package controllers import ( + "fmt" "net/http/httptest" "os" + "path/filepath" "testing" "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" - "github.com/stretchr/testify/suite" ) -// ControllersSuite is a suite of tests to cover API related functions -type ControllersSuite struct { - suite.Suite +// testContext is the data required to test API related functions +type testContext struct { apiKey string config *config.Config adminServer *httptest.Server phishServer *httptest.Server + origPath string } -func (s *ControllersSuite) SetupSuite() { +func setupTest(t *testing.T) *testContext { + wd, _ := os.Getwd() + fmt.Println(wd) conf := &config.Config{ DBName: "sqlite3", DBPath: ":memory:", MigrationsPath: "../db/db_sqlite3/migrations/", } + abs, _ := filepath.Abs("../db/db_sqlite3/migrations/") + fmt.Printf("in controllers_test.go: %s\n", abs) err := models.Setup(conf) if err != nil { - s.T().Fatalf("Failed creating database: %v", err) + t.Fatalf("error setting up database: %v", err) } - s.config = conf - s.Nil(err) - // Setup the admin server for use in testing - s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler) - s.adminServer.Config.Addr = s.config.AdminConf.ListenURL - s.adminServer.Start() + ctx := &testContext{} + ctx.config = conf + ctx.adminServer = httptest.NewUnstartedServer(NewAdminServer(ctx.config.AdminConf).server.Handler) + ctx.adminServer.Config.Addr = ctx.config.AdminConf.ListenURL + ctx.adminServer.Start() // Get the API key to use for these tests u, err := models.GetUser(1) - s.Nil(err) - s.apiKey = u.ApiKey + if err != nil { + t.Fatalf("error getting first user from database: %v", err) + } + ctx.apiKey = u.ApiKey // Start the phishing server - s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler) - s.phishServer.Config.Addr = s.config.PhishConf.ListenURL - s.phishServer.Start() + ctx.phishServer = httptest.NewUnstartedServer(NewPhishingServer(ctx.config.PhishConf).server.Handler) + ctx.phishServer.Config.Addr = ctx.config.PhishConf.ListenURL + ctx.phishServer.Start() // Move our cwd up to the project root for help with resolving // static assets + origPath, _ := os.Getwd() + ctx.origPath = origPath err = os.Chdir("../") - s.Nil(err) -} - -func (s *ControllersSuite) TearDownTest() { - campaigns, _ := models.GetCampaigns(1) - for _, campaign := range campaigns { - models.DeleteCampaign(campaign.Id) + if err != nil { + t.Fatalf("error changing directories to setup asset discovery: %v", err) } + createTestData(t) + return ctx } -func (s *ControllersSuite) SetupTest() { +func tearDown(t *testing.T, ctx *testContext) { + // Tear down the admin and phishing servers + ctx.adminServer.Close() + ctx.phishServer.Close() + // Reset the path for the next test + os.Chdir(ctx.origPath) +} + +func createTestData(t *testing.T) { // Add a group group := models.Group{Name: "Test Group"} group.Targets = []models.Target{ @@ -67,12 +80,12 @@ func (s *ControllersSuite) SetupTest() { models.PostGroup(&group) // Add a template - t := models.Template{Name: "Test Template"} - t.Subject = "Test subject" - t.Text = "Text text" - t.HTML = "Test" - t.UserId = 1 - models.PostTemplate(&t) + template := models.Template{Name: "Test Template"} + template.Subject = "Test subject" + template.Text = "Text text" + template.HTML = "Test" + template.UserId = 1 + models.PostTemplate(&template) // Add a landing page p := models.Page{Name: "Test Page"} @@ -91,20 +104,10 @@ func (s *ControllersSuite) SetupTest() { // Set the status such that no emails are attempted c := models.Campaign{Name: "Test campaign"} c.UserId = 1 - c.Template = t + c.Template = template c.Page = p c.SMTP = smtp c.Groups = []models.Group{group} models.PostCampaign(&c, c.UserId) c.UpdateStatus(models.CampaignEmailsSent) } - -func (s *ControllersSuite) TearDownSuite() { - // Tear down the admin and phishing servers - s.adminServer.Close() - s.phishServer.Close() -} - -func TestControllerSuite(t *testing.T) { - suite.Run(t, new(ControllersSuite)) -} diff --git a/controllers/phish_test.go b/controllers/phish_test.go index c1b53c54..87490bed 100644 --- a/controllers/phish_test.go +++ b/controllers/phish_test.go @@ -5,22 +5,25 @@ import ( "encoding/json" "fmt" "io/ioutil" - "log" "net/http" "net/url" + "reflect" + "testing" "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" ) -func (s *ControllersSuite) getFirstCampaign() models.Campaign { +func getFirstCampaign(t *testing.T) models.Campaign { campaigns, err := models.GetCampaigns(1) - s.Nil(err) + if err != nil { + t.Fatalf("error getting first campaign from database: %v", err) + } return campaigns[0] } -func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest { - campaign := s.getFirstCampaign() +func getFirstEmailRequest(t *testing.T) models.EmailRequest { + campaign := getFirstCampaign(t) req := models.EmailRequest{ TemplateId: campaign.TemplateId, Template: campaign.Template, @@ -33,205 +36,334 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest { FromAddress: campaign.SMTP.FromAddress, } err := models.PostEmailRequest(&req) - s.Nil(err) + if err != nil { + t.Fatalf("error creating email request: %v", err) + } return req } -func (s *ControllersSuite) openEmail(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) +func openEmail(t *testing.T, ctx *testContext, rid string) { + resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting /track endpoint: %v", err) + } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - s.Nil(err) + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body from /track endpoint: %v", err) + } expected, err := ioutil.ReadFile("static/images/pixel.png") - s.Nil(err) - s.Equal(bytes.Compare(body, expected), 0) + if err != nil { + t.Fatalf("error reading local transparent pixel: %v", err) + } + if !bytes.Equal(got, expected) { + t.Fatalf("unexpected tracking pixel data received. expected %#v got %#v", expected, got) + } } -func (s *ControllersSuite) reportedEmail(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNoContent) -} - -func (s *ControllersSuite) reportEmail404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) -} - -func (s *ControllersSuite) openEmail404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) +func openEmail404(t *testing.T, ctx *testContext, rid string) { + resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting /track endpoint: %v", err) + } defer resp.Body.Close() - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) + got := resp.StatusCode + expected := http.StatusNotFound + if got != expected { + t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) clickLink(rid string, expectedHTML string) { - resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - s.Nil(err) - log.Printf("%s\n\n\n", body) - s.Equal(bytes.Compare(body, []byte(expectedHTML)), 0) +func reportedEmail(t *testing.T, ctx *testContext, rid string) { + resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting /report endpoint: %v", err) + } + got := resp.StatusCode + expected := http.StatusNoContent + if got != expected { + t.Fatalf("invalid status code received for /report endpoint. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) clickLink404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) - s.Nil(err) - defer resp.Body.Close() - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) +func reportEmail404(t *testing.T, ctx *testContext, rid string) { + resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting /report endpoint: %v", err) + } + got := resp.StatusCode + expected := http.StatusNotFound + if got != expected { + t.Fatalf("invalid status code received for /report endpoint. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) { - resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid)) - s.Nil(err) +func clickLink(t *testing.T, ctx *testContext, rid string, expectedHTML string) { + resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting / endpoint: %v", err) + } defer resp.Body.Close() - s.Equal(resp.StatusCode, http.StatusOK) + got, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading payload from / endpoint response: %v", err) + } + if !bytes.Equal(got, []byte(expectedHTML)) { + t.Fatalf("invalid response received from / endpoint. expected %s got %s", got, expectedHTML) + } +} + +func clickLink404(t *testing.T, ctx *testContext, rid string) { + resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting / endpoint: %v", err) + } + defer resp.Body.Close() + got := resp.StatusCode + expected := http.StatusNotFound + if got != expected { + t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got) + } +} + +func transparencyRequest(t *testing.T, ctx *testContext, r models.Result, rid, path string) { + resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ctx.phishServer.URL, path, models.RecipientParameter, rid)) + if err != nil { + t.Fatalf("error requesting %s endpoint: %v", path, err) + } + defer resp.Body.Close() + got := resp.StatusCode + expected := http.StatusOK + if got != expected { + t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got) + } tr := &TransparencyResponse{} err = json.NewDecoder(resp.Body).Decode(tr) - s.Nil(err) - s.Equal(tr.ContactAddress, s.config.ContactAddress) - s.Equal(tr.SendDate, r.SendDate) - s.Equal(tr.Server, config.ServerName) + if err != nil { + t.Fatalf("error unmarshaling transparency request: %v", err) + } + expectedResponse := &TransparencyResponse{ + ContactAddress: ctx.config.ContactAddress, + SendDate: r.SendDate, + Server: config.ServerName, + } + if !reflect.DeepEqual(tr, expectedResponse) { + t.Fatalf("unexpected transparency response received. expected %v got %v", expectedResponse, tr) + } } -func (s *ControllersSuite) TestOpenedPhishingEmail() { - campaign := s.getFirstCampaign() +func TestOpenedPhishingEmail(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + campaign := getFirstCampaign(t) result := campaign.Results[0] - s.Equal(result.Status, models.StatusSending) + if result.Status != models.StatusSending { + t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status) + } - s.openEmail(result.RId) + openEmail(t, ctx, result.RId) - campaign = s.getFirstCampaign() + campaign = getFirstCampaign(t) result = campaign.Results[0] lastEvent := campaign.Events[len(campaign.Events)-1] - s.Equal(result.Status, models.EventOpened) - s.Equal(lastEvent.Message, models.EventOpened) - s.Equal(result.ModifiedDate, lastEvent.Time) + if result.Status != models.EventOpened { + t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status) + } + if lastEvent.Message != models.EventOpened { + t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventOpened) + } + if result.ModifiedDate != lastEvent.Time { + t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate) + } } -func (s *ControllersSuite) TestReportedPhishingEmail() { - campaign := s.getFirstCampaign() +func TestReportedPhishingEmail(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + campaign := getFirstCampaign(t) result := campaign.Results[0] - s.Equal(result.Status, models.StatusSending) + if result.Status != models.StatusSending { + t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status) + } - s.reportedEmail(result.RId) + reportedEmail(t, ctx, result.RId) - campaign = s.getFirstCampaign() + campaign = getFirstCampaign(t) result = campaign.Results[0] lastEvent := campaign.Events[len(campaign.Events)-1] - s.Equal(result.Reported, true) - s.Equal(lastEvent.Message, models.EventReported) - s.Equal(result.ModifiedDate, lastEvent.Time) + + if result.Reported != true { + t.Fatalf("unexpected result report status received. expected %v got %v", true, result.Reported) + } + if lastEvent.Message != models.EventReported { + t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventReported) + } + if result.ModifiedDate != lastEvent.Time { + t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate) + } } -func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() { - campaign := s.getFirstCampaign() +func TestClickedPhishingLinkAfterOpen(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + campaign := getFirstCampaign(t) result := campaign.Results[0] - s.Equal(result.Status, models.StatusSending) + if result.Status != models.StatusSending { + t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status) + } - s.openEmail(result.RId) - s.clickLink(result.RId, campaign.Page.HTML) + openEmail(t, ctx, result.RId) + clickLink(t, ctx, result.RId, campaign.Page.HTML) - campaign = s.getFirstCampaign() + campaign = getFirstCampaign(t) result = campaign.Results[0] lastEvent := campaign.Events[len(campaign.Events)-1] - s.Equal(result.Status, models.EventClicked) - s.Equal(lastEvent.Message, models.EventClicked) - s.Equal(result.ModifiedDate, lastEvent.Time) + if result.Status != models.EventClicked { + t.Fatalf("unexpected result status received. expected %s got %s", models.EventClicked, result.Status) + } + if lastEvent.Message != models.EventClicked { + t.Fatalf("unexpected event status received. expected %s got %s", lastEvent.Message, models.EventClicked) + } + if result.ModifiedDate != lastEvent.Time { + t.Fatalf("unexpected result modified date received. expected %s got %s", lastEvent.Time, result.ModifiedDate) + } } -func (s *ControllersSuite) TestNoRecipientID() { - resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL)) - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) +func TestNoRecipientID(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + resp, err := http.Get(fmt.Sprintf("%s/track", ctx.phishServer.URL)) + if err != nil { + t.Fatalf("error requesting /track endpoint: %v", err) + } + got := resp.StatusCode + expected := http.StatusNotFound + if got != expected { + t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expected, got) + } - resp, err = http.Get(s.phishServer.URL) - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) + resp, err = http.Get(ctx.phishServer.URL) + if err != nil { + t.Fatalf("error requesting /track endpoint: %v", err) + } + got = resp.StatusCode + if got != expected { + t.Fatalf("invalid status code received for / endpoint. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) TestInvalidRecipientID() { +func TestInvalidRecipientID(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) rid := "XXXXXXXXXX" - s.openEmail404(rid) - s.clickLink404(rid) - s.reportEmail404(rid) + openEmail404(t, ctx, rid) + clickLink404(t, ctx, rid) + reportEmail404(t, ctx, rid) } -func (s *ControllersSuite) TestCompletedCampaignClick() { - campaign := s.getFirstCampaign() +func TestCompletedCampaignClick(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + campaign := getFirstCampaign(t) result := campaign.Results[0] - s.Equal(result.Status, models.StatusSending) - s.openEmail(result.RId) + if result.Status != models.StatusSending { + t.Fatalf("unexpected result status received. expected %s got %s", models.StatusSending, result.Status) + } - campaign = s.getFirstCampaign() + openEmail(t, ctx, result.RId) + + campaign = getFirstCampaign(t) result = campaign.Results[0] - s.Equal(result.Status, models.EventOpened) + if result.Status != models.EventOpened { + t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status) + } models.CompleteCampaign(campaign.Id, 1) - s.openEmail404(result.RId) - s.clickLink404(result.RId) + openEmail404(t, ctx, result.RId) + clickLink404(t, ctx, result.RId) - campaign = s.getFirstCampaign() + campaign = getFirstCampaign(t) result = campaign.Results[0] - s.Equal(result.Status, models.EventOpened) + if result.Status != models.EventOpened { + t.Fatalf("unexpected result status received. expected %s got %s", models.EventOpened, result.Status) + } } -func (s *ControllersSuite) TestRobotsHandler() { - expected := []byte("User-agent: *\nDisallow: /\n") - resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL)) - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusOK) +func TestRobotsHandler(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ctx.phishServer.URL)) + if err != nil { + t.Fatalf("error requesting /robots.txt endpoint: %v", err) + } defer resp.Body.Close() + got := resp.StatusCode + expectedStatus := http.StatusOK + if got != expectedStatus { + t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expectedStatus, got) + } + expected := []byte("User-agent: *\nDisallow: /\n") body, err := ioutil.ReadAll(resp.Body) - s.Nil(err) - s.Equal(bytes.Compare(body, expected), 0) + if err != nil { + t.Fatalf("error reading response body from /robots.txt endpoint: %v", err) + } + if !bytes.Equal(body, expected) { + t.Fatalf("invalid robots.txt response received. expected %s got %s", expected, body) + } } -func (s *ControllersSuite) TestInvalidPreviewID() { +func TestInvalidPreviewID(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) bogusRId := fmt.Sprintf("%sbogus", models.PreviewPrefix) - s.openEmail404(bogusRId) - s.clickLink404(bogusRId) - s.reportEmail404(bogusRId) + openEmail404(t, ctx, bogusRId) + clickLink404(t, ctx, bogusRId) + reportEmail404(t, ctx, bogusRId) } -func (s *ControllersSuite) TestPreviewTrack() { - req := s.getFirstEmailRequest() - s.openEmail(req.RId) +func TestPreviewTrack(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + req := getFirstEmailRequest(t) + openEmail(t, ctx, req.RId) } -func (s *ControllersSuite) TestPreviewClick() { - req := s.getFirstEmailRequest() - s.clickLink(req.RId, req.Page.HTML) +func TestPreviewClick(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + req := getFirstEmailRequest(t) + clickLink(t, ctx, req.RId, req.Page.HTML) } -func (s *ControllersSuite) TestInvalidTransparencyRequest() { +func TestInvalidTransparencyRequest(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) bogusRId := fmt.Sprintf("bogus%s", TransparencySuffix) - s.openEmail404(bogusRId) - s.clickLink404(bogusRId) - s.reportEmail404(bogusRId) + openEmail404(t, ctx, bogusRId) + clickLink404(t, ctx, bogusRId) + reportEmail404(t, ctx, bogusRId) } -func (s *ControllersSuite) TestTransparencyRequest() { - campaign := s.getFirstCampaign() +func TestTransparencyRequest(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + campaign := getFirstCampaign(t) result := campaign.Results[0] rid := fmt.Sprintf("%s%s", result.RId, TransparencySuffix) - s.transparencyRequest(result, rid, "/") - s.transparencyRequest(result, rid, "/track") - s.transparencyRequest(result, rid, "/report") + transparencyRequest(t, ctx, result, rid, "/") + transparencyRequest(t, ctx, result, rid, "/track") + transparencyRequest(t, ctx, result, rid, "/report") // And check with the URL encoded version of a + rid = fmt.Sprintf("%s%s", result.RId, "%2b") - s.transparencyRequest(result, rid, "/") - s.transparencyRequest(result, rid, "/track") - s.transparencyRequest(result, rid, "/report") + transparencyRequest(t, ctx, result, rid, "/") + transparencyRequest(t, ctx, result, rid, "/track") + transparencyRequest(t, ctx, result, rid, "/report") } -func (s *ControllersSuite) TestRedirectTemplating() { +func TestRedirectTemplating(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) p := models.Page{ Name: "Redirect Page", HTML: "Test", @@ -239,7 +371,9 @@ func (s *ControllersSuite) TestRedirectTemplating() { RedirectURL: "http://example.com/{{.RId}}", } err := models.PostPage(&p) - s.Nil(err) + if err != nil { + t.Fatalf("error posting new page: %v", err) + } smtp, _ := models.GetSMTP(1, 1) template, _ := models.GetTemplate(1, 1) group, _ := models.GetGroup(1, 1) @@ -251,7 +385,9 @@ func (s *ControllersSuite) TestRedirectTemplating() { campaign.SMTP = smtp campaign.Groups = []models.Group{group} err = models.PostCampaign(&campaign, campaign.UserId) - s.Nil(err) + if err != nil { + t.Fatalf("error creating campaign: %v", err) + } client := http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -259,12 +395,22 @@ func (s *ControllersSuite) TestRedirectTemplating() { }, } result := campaign.Results[0] - resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) - s.Nil(err) + resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ctx.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) + if err != nil { + t.Fatalf("error requesting / endpoint: %v", err) + } defer resp.Body.Close() - s.Equal(http.StatusFound, resp.StatusCode) + got := resp.StatusCode + expectedStatus := http.StatusFound + if got != expectedStatus { + t.Fatalf("invalid status code received for /track endpoint. expected %d got %d", expectedStatus, got) + } expectedURL := fmt.Sprintf("http://example.com/%s", result.RId) - got, err := resp.Location() - s.Nil(err) - s.Equal(expectedURL, got.String()) + gotURL, err := resp.Location() + if err != nil { + t.Fatalf("error getting Location header from response: %v", err) + } + if gotURL.String() != expectedURL { + t.Fatalf("invalid redirect received. expected %s got %s", expectedURL, gotURL) + } } diff --git a/controllers/route_test.go b/controllers/route_test.go index c9c7975e..c3bd186d 100644 --- a/controllers/route_test.go +++ b/controllers/route_test.go @@ -5,106 +5,115 @@ import ( "net/http" "net/url" "strings" + "testing" "github.com/PuerkitoBio/goquery" ) -func (s *ControllersSuite) TestLoginCSRF() { - resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL), +func attemptLogin(t *testing.T, ctx *testContext, client *http.Client, username, password, optionalPath string) *http.Response { + resp, err := http.Get(fmt.Sprintf("%s/login", ctx.adminServer.URL)) + if err != nil { + t.Fatalf("error requesting the /login endpoint: %v", err) + } + got := resp.StatusCode + expected := http.StatusOK + if got != expected { + t.Fatalf("invalid status code received. expected %d got %d", expected, got) + } + + doc, err := goquery.NewDocumentFromResponse(resp) + if err != nil { + t.Fatalf("error parsing /login response body") + } + elem := doc.Find("input[name='csrf_token']").First() + token, ok := elem.Attr("value") + if !ok { + t.Fatal("unable to find csrf_token value in login response") + } + if client == nil { + client = &http.Client{} + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/login%s", ctx.adminServer.URL, optionalPath), strings.NewReader(url.Values{ + "username": {username}, + "password": {password}, + "csrf_token": {token}, + }.Encode())) + if err != nil { + t.Fatalf("error creating new /login request: %v", err) + } + + req.Header.Set("Cookie", resp.Header.Get("Set-Cookie")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err = client.Do(req) + if err != nil { + t.Fatalf("error requesting the /login endpoint: %v", err) + } + return resp +} + +func TestLoginCSRF(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + resp, err := http.PostForm(fmt.Sprintf("%s/login", ctx.adminServer.URL), url.Values{ "username": {"admin"}, "password": {"gophish"}, }) - s.Equal(resp.StatusCode, http.StatusForbidden) - fmt.Println(err) + if err != nil { + t.Fatalf("error requesting the /login endpoint: %v", err) + } + + got := resp.StatusCode + expected := http.StatusForbidden + if got != expected { + t.Fatalf("invalid status code received. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) TestInvalidCredentials() { - resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusOK) - - doc, err := goquery.NewDocumentFromResponse(resp) - s.Equal(err, nil) - elem := doc.Find("input[name='csrf_token']").First() - token, ok := elem.Attr("value") - s.Equal(ok, true) - - client := &http.Client{} - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{ - "username": {"admin"}, - "password": {"invalid"}, - "csrf_token": {token}, - }.Encode())) - s.Equal(err, nil) - - req.Header.Set("Cookie", resp.Header.Get("Set-Cookie")) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - resp, err = client.Do(req) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusUnauthorized) +func TestInvalidCredentials(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + resp := attemptLogin(t, ctx, nil, "admin", "bogus", "") + got := resp.StatusCode + expected := http.StatusUnauthorized + if got != expected { + t.Fatalf("invalid status code received. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) TestSuccessfulLogin() { - resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusOK) - - doc, err := goquery.NewDocumentFromResponse(resp) - s.Equal(err, nil) - elem := doc.Find("input[name='csrf_token']").First() - token, ok := elem.Attr("value") - s.Equal(ok, true) - - client := &http.Client{} - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{ - "username": {"admin"}, - "password": {"gophish"}, - "csrf_token": {token}, - }.Encode())) - s.Equal(err, nil) - - req.Header.Set("Cookie", resp.Header.Get("Set-Cookie")) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - resp, err = client.Do(req) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusOK) +func TestSuccessfulLogin(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) + resp := attemptLogin(t, ctx, nil, "admin", "gophish", "") + got := resp.StatusCode + expected := http.StatusOK + if got != expected { + t.Fatalf("invalid status code received. expected %d got %d", expected, got) + } } -func (s *ControllersSuite) TestSuccessfulRedirect() { +func TestSuccessfulRedirect(t *testing.T) { + ctx := setupTest(t) + defer tearDown(t, ctx) next := "/campaigns" - resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusOK) - - doc, err := goquery.NewDocumentFromResponse(resp) - s.Equal(err, nil) - elem := doc.Find("input[name='csrf_token']").First() - token, ok := elem.Attr("value") - s.Equal(ok, true) - client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse - }, + }} + resp := attemptLogin(t, ctx, client, "admin", "gophish", fmt.Sprintf("?next=%s", next)) + got := resp.StatusCode + expected := http.StatusFound + if got != expected { + t.Fatalf("invalid status code received. expected %d got %d", expected, got) } - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{ - "username": {"admin"}, - "password": {"gophish"}, - "csrf_token": {token}, - }.Encode())) - s.Equal(err, nil) - - req.Header.Set("Cookie", resp.Header.Get("Set-Cookie")) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - - resp, err = client.Do(req) - s.Equal(err, nil) - s.Equal(resp.StatusCode, http.StatusFound) url, err := resp.Location() - s.Equal(err, nil) - s.Equal(url.Path, next) + if err != nil { + t.Fatalf("error parsing response Location header: %v", err) + } + if url.Path != next { + t.Fatalf("unexpected Location header received. expected %s got %s", next, url.Path) + } } diff --git a/mailer/mailer_test.go b/mailer/mailer_test.go index a8c0450f..0ae596e8 100644 --- a/mailer/mailer_test.go +++ b/mailer/mailer_test.go @@ -8,14 +8,8 @@ import ( "net/textproto" "reflect" "testing" - - "github.com/stretchr/testify/suite" ) -type MailerSuite struct { - suite.Suite -} - func generateMessages(dialer Dialer) []Mail { to := []string{"to@example.com"} @@ -47,30 +41,30 @@ func newMockErrorSender(err error) *mockSender { return sender } -func (ms *MailerSuite) TestDialHost() { +func TestDialHost(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() md := newMockDialer() md.setDial(md.unreachableDial) _, err := dialHost(ctx, md) if _, ok := err.(*ErrMaxConnectAttempts); !ok { - ms.T().Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err) + t.Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err) } e := err.(*ErrMaxConnectAttempts) if e.underlyingError != errHostUnreachable { - ms.T().Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable) + t.Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable) } if md.dialCount != MaxReconnectAttempts { - ms.T().Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount) + t.Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount) } md.setDial(md.defaultDial) _, err = dialHost(ctx, md) if err != nil { - ms.T().Fatalf("Unexpected error when dialing the mock host: %s", err) + t.Fatalf("Unexpected error when dialing the mock host: %s", err) } } -func (ms *MailerSuite) TestMailWorkerStart() { +func TestMailWorkerStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -97,16 +91,16 @@ func (ms *MailerSuite) TestMailWorkerStart() { got = append(got, message) original := messages[idx].(*mockMessage) if original.from != message.from { - ms.T().Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from) + t.Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from) } idx++ } if len(got) != len(messages) { - ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages)) + t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages)) } } -func (ms *MailerSuite) TestBackoff() { +func TestBackoff(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -139,28 +133,28 @@ func (ms *MailerSuite) TestBackoff() { // Check that we only sent one message expectedCount := 1 if len(got) != expectedCount { - ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) + t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) } // Check that it's the correct message originalFrom := messages[1].(*mockMessage).from if got[0].from != originalFrom { - ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) + t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) } // Check that the first message performed a backoff backoffCount := messages[0].(*mockMessage).backoffCount if backoffCount != expectedCount { - ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) + t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) } // Check that there was a reset performed on the sender if sender.resetCount != expectedCount { - ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) + t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) } } -func (ms *MailerSuite) TestPermError() { +func TestPermError(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -193,13 +187,13 @@ func (ms *MailerSuite) TestPermError() { // Check that we only sent one message expectedCount := 1 if len(got) != expectedCount { - ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) + t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) } // Check that it's the correct message originalFrom := messages[1].(*mockMessage).from if got[0].from != originalFrom { - ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) + t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) } message := messages[0].(*mockMessage) @@ -208,21 +202,21 @@ func (ms *MailerSuite) TestPermError() { expectedBackoffCount := 0 backoffCount := message.backoffCount if backoffCount != expectedBackoffCount { - ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) + t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount) } // Check that there was a reset performed on the sender if sender.resetCount != expectedCount { - ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) + t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount) } // Check that the email errored out appropriately if !reflect.DeepEqual(message.err, expectedError) { - ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) + t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) } } -func (ms *MailerSuite) TestUnknownError() { +func TestUnknownError(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -252,13 +246,13 @@ func (ms *MailerSuite) TestUnknownError() { // Check that we only sent one message expectedCount := 1 if len(got) != expectedCount { - ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) + t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount) } // Check that it's the correct message originalFrom := messages[1].(*mockMessage).from if got[0].from != originalFrom { - ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) + t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from) } message := messages[0].(*mockMessage) @@ -271,21 +265,17 @@ func (ms *MailerSuite) TestUnknownError() { expectedBackoffCount := 1 backoffCount := message.backoffCount if backoffCount != expectedBackoffCount { - ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount) + t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount) } // Check that the underlying connection was reestablished expectedDialCount := 2 if dialer.dialCount != expectedDialCount { - ms.T().Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount) + t.Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount) } // Check that the email errored out appropriately if !reflect.DeepEqual(message.err, expectedError) { - ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) + t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError) } } - -func TestMailerSuite(t *testing.T) { - suite.Run(t, new(MailerSuite)) -} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 30752b01..e7621979 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -9,19 +9,17 @@ import ( "github.com/gophish/gophish/config" ctx "github.com/gophish/gophish/context" "github.com/gophish/gophish/models" - "github.com/stretchr/testify/suite" ) var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("success")) }) -type MiddlewareSuite struct { - suite.Suite +type testContext struct { apiKey string } -func (s *MiddlewareSuite) SetupSuite() { +func setupTest(t *testing.T) *testContext { conf := &config.Config{ DBName: "sqlite3", DBPath: ":memory:", @@ -29,12 +27,16 @@ func (s *MiddlewareSuite) SetupSuite() { } err := models.Setup(conf) if err != nil { - s.T().Fatalf("Failed creating database: %v", err) + t.Fatalf("Failed creating database: %v", err) } // Get the API key to use for these tests u, err := models.GetUser(1) - s.Nil(err) - s.apiKey = u.ApiKey + if err != nil { + t.Fatalf("error getting user: %v", err) + } + ctx := &testContext{} + ctx.apiKey = u.ApiKey + return ctx } // MiddlewarePermissionTest maps an expected HTTP Method to an expected HTTP @@ -43,7 +45,8 @@ type MiddlewarePermissionTest map[string]int // TestEnforceViewOnly ensures that only users with the ModifyObjects // permission have the ability to send non-GET requests. -func (s *MiddlewareSuite) TestEnforceViewOnly() { +func TestEnforceViewOnly(t *testing.T) { + setupTest(t) permissionTests := map[string]MiddlewarePermissionTest{ models.RoleAdmin: MiddlewarePermissionTest{ http.MethodGet: http.StatusOK, @@ -64,7 +67,9 @@ func (s *MiddlewareSuite) TestEnforceViewOnly() { } for r, checks := range permissionTests { role, err := models.GetRoleBySlug(r) - s.Nil(err) + if err != nil { + t.Fatalf("error getting role by slug: %v", err) + } for method, expected := range checks { req := httptest.NewRequest(method, "/", nil) @@ -76,12 +81,16 @@ func (s *MiddlewareSuite) TestEnforceViewOnly() { }) EnforceViewOnly(successHandler).ServeHTTP(response, req) - s.Equal(response.Code, expected) + got := response.Code + if got != expected { + t.Fatalf("incorrect status code received. expected %d got %d", expected, got) + } } } } -func (s *MiddlewareSuite) TestRequirePermission() { +func TestRequirePermission(t *testing.T) { + setupTest(t) middleware := RequirePermission(models.PermissionModifySystem) handler := middleware(successHandler) @@ -95,26 +104,37 @@ func (s *MiddlewareSuite) TestRequirePermission() { response := httptest.NewRecorder() // Test that with the requested permission, the request succeeds role, err := models.GetRoleBySlug(role) - s.Nil(err) + if err != nil { + t.Fatalf("error getting role by slug: %v", err) + } req = ctx.Set(req, "user", models.User{ Role: role, RoleID: role.ID, }) handler.ServeHTTP(response, req) - s.Equal(response.Code, expected) + got := response.Code + if got != expected { + t.Fatalf("incorrect status code received. expected %d got %d", expected, got) + } } } -func (s *MiddlewareSuite) TestRequireAPIKey() { +func TestRequireAPIKey(t *testing.T) { + setupTest(t) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Content-Type", "application/json") response := httptest.NewRecorder() // Test that making a request without an API key is denied RequireAPIKey(successHandler).ServeHTTP(response, req) - s.Equal(response.Code, http.StatusUnauthorized) + expected := http.StatusUnauthorized + got := response.Code + if got != expected { + t.Fatalf("incorrect status code received. expected %d got %d", expected, got) + } } -func (s *MiddlewareSuite) TestInvalidAPIKey() { +func TestInvalidAPIKey(t *testing.T) { + setupTest(t) req := httptest.NewRequest(http.MethodGet, "/", nil) query := req.URL.Query() query.Set("api_key", "bogus-api-key") @@ -122,18 +142,23 @@ func (s *MiddlewareSuite) TestInvalidAPIKey() { req.Header.Set("Content-Type", "application/json") response := httptest.NewRecorder() RequireAPIKey(successHandler).ServeHTTP(response, req) - s.Equal(response.Code, http.StatusUnauthorized) + expected := http.StatusUnauthorized + got := response.Code + if got != expected { + t.Fatalf("incorrect status code received. expected %d got %d", expected, got) + } } -func (s *MiddlewareSuite) TestBearerToken() { +func TestBearerToken(t *testing.T) { + testCtx := setupTest(t) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.apiKey)) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", testCtx.apiKey)) req.Header.Set("Content-Type", "application/json") response := httptest.NewRecorder() RequireAPIKey(successHandler).ServeHTTP(response, req) - s.Equal(response.Code, http.StatusOK) -} - -func TestMiddlewareSuite(t *testing.T) { - suite.Run(t, new(MiddlewareSuite)) + expected := http.StatusOK + got := response.Code + if got != expected { + t.Fatalf("incorrect status code received. expected %d got %d", expected, got) + } } diff --git a/models/models.go b/models/models.go index 3384aad5..098973e2 100644 --- a/models/models.go +++ b/models/models.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "path/filepath" "time" "bitbucket.org/liamstask/goose/lib/goose" @@ -93,6 +94,8 @@ func Setup(c *config.Config) error { Env: "production", Driver: chooseDBDriver(conf.DBName, conf.DBPath), } + abs, _ := filepath.Abs(migrateConf.MigrationsDir) + fmt.Println(abs) // Get the latest possible migration latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir) if err != nil { diff --git a/util/util_test.go b/util/util_test.go index 874e3cbf..83c2a78f 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -8,28 +8,9 @@ import ( "reflect" "testing" - "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" - "github.com/stretchr/testify/suite" ) -type UtilSuite struct { - suite.Suite -} - -func (s *UtilSuite) SetupSuite() { - conf := &config.Config{ - DBName: "sqlite3", - DBPath: ":memory:", - MigrationsPath: "../db/db_sqlite3/migrations/", - } - err := models.Setup(conf) - if err != nil { - s.T().Fatalf("Failed creating database: %v", err) - } - s.Nil(err) -} - func buildCSVRequest(csvPayload string) (*http.Request, error) { csvHeader := "First Name,Last Name,Email\n" body := new(bytes.Buffer) @@ -52,7 +33,7 @@ func buildCSVRequest(csvPayload string) (*http.Request, error) { return r, nil } -func (s *UtilSuite) TestParseCSVEmail() { +func TestParseCSVEmail(t *testing.T) { expected := models.Target{ BaseRecipient: models.BaseRecipient{ FirstName: "John", @@ -63,16 +44,19 @@ func (s *UtilSuite) TestParseCSVEmail() { csvPayload := fmt.Sprintf("%s,%s,<%s>", expected.FirstName, expected.LastName, expected.Email) r, err := buildCSVRequest(csvPayload) - s.Nil(err) + if err != nil { + t.Fatalf("error building CSV request: %v", err) + } got, err := ParseCSV(r) - s.Nil(err) - s.Equal(len(got), 1) + if err != nil { + t.Fatalf("error parsing CSV: %v", err) + } + expectedLength := 1 + if len(got) != expectedLength { + t.Fatalf("invalid number of results received from CSV. expected %d got %d", expectedLength, len(got)) + } if !reflect.DeepEqual(expected, got[0]) { - s.T().Fatalf("Incorrect targets received. Expected: %#v\nGot: %#v", expected, got) + t.Fatalf("Incorrect targets received. Expected: %#v\nGot: %#v", expected, got) } } - -func TestUtilSuite(t *testing.T) { - suite.Run(t, new(UtilSuite)) -} diff --git a/webhook/webhook_test.go b/webhook/webhook_test.go index 8cf0dd12..6fbb07b2 100644 --- a/webhook/webhook_test.go +++ b/webhook/webhook_test.go @@ -7,16 +7,10 @@ import ( "log" "net/http" "net/http/httptest" + "reflect" "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" ) -type WebhookSuite struct { - suite.Suite -} - type mockSender struct { client *http.Client } @@ -33,22 +27,24 @@ func (ms mockSender) Send(endPoint EndPoint, data interface{}) error { return nil } -func (s *WebhookSuite) TestSendMocked() { - mcSnd := newMockSender() - endp1 := EndPoint{URL: "http://example.com/a1", Secret: "s1"} - d1 := map[string]string{ +func TestSendMocked(t *testing.T) { + ms := newMockSender() + endpoint := EndPoint{URL: "http://example.com/a1", Secret: "s1"} + data := map[string]string{ "a1": "a11", "a2": "a22", "a3": "a33", } - err := mcSnd.Send(endp1, d1) - s.Nil(err) + err := ms.Send(endpoint, data) + if err != nil { + t.Fatalf("error sending data to webhook endpoint: %v", err) + } } -func (s *WebhookSuite) TestSendReal() { - expectedSign := "004b36ca3fcbc01a08b17bf5d4a7e1aa0b10e14f55f3f8bd9acac0c7e8d2635d" +func TestSendReal(t *testing.T) { + expectedSig := "004b36ca3fcbc01a08b17bf5d4a7e1aa0b10e14f55f3f8bd9acac0c7e8d2635d" secret := "secret456" - d1 := map[string]interface{}{ + data := map[string]interface{}{ "key1": "val1", "key2": "val2", "key3": "val3", @@ -58,37 +54,50 @@ func (s *WebhookSuite) TestSendReal() { fmt.Println("[test] running the server...") signStartIdx := len(Sha256Prefix) + 1 - realSignRaw := r.Header.Get(SignatureHeader) - realSign := realSignRaw[signStartIdx:] - assert.Equal(s.T(), expectedSign, realSign) + sigHeader := r.Header.Get(SignatureHeader) + gotSig := sigHeader[signStartIdx:] + if expectedSig != gotSig { + t.Fatalf("invalid signature received. expected %s got %s", expectedSig, gotSig) + } - contTypeJsonHeader := r.Header.Get("Content-Type") - assert.Equal(s.T(), contTypeJsonHeader, "application/json") + ct := r.Header.Get("Content-Type") + expectedCT := "application/json" + if ct != expectedCT { + t.Fatalf("invalid content type. expected %s got %s", ct, expectedCT) + } body, err := ioutil.ReadAll(r.Body) - s.Nil(err) + if err != nil { + t.Fatalf("error reading JSON body from webhook request: %v", err) + } - var d2 map[string]interface{} - err = json.Unmarshal(body, &d2) - s.Nil(err) - assert.Equal(s.T(), d1, d2) + var payload map[string]interface{} + err = json.Unmarshal(body, &payload) + if err != nil { + t.Fatalf("error unmarshaling webhook payload: %v", err) + } + if !reflect.DeepEqual(data, payload) { + t.Fatalf("invalid payload received. expected %#v got %#v", data, payload) + } })) defer ts.Close() endp1 := EndPoint{URL: ts.URL, Secret: secret} - err := Send(endp1, d1) - s.Nil(err) + err := Send(endp1, data) + if err != nil { + t.Fatalf("error sending data to webhook endpoint: %v", err) + } } -func (s *WebhookSuite) TestSignature() { +func TestSignature(t *testing.T) { secret := "secret123" payload := []byte("some payload456") - expectedSign := "ab7844c1e9149f8dc976c4188a72163c005930f3c2266a163ffe434230bdf761" - realSign, err := sign(secret, payload) - s.Nil(err) - assert.Equal(s.T(), expectedSign, realSign) -} - -func TestWebhookSuite(t *testing.T) { - suite.Run(t, new(WebhookSuite)) + expected := "ab7844c1e9149f8dc976c4188a72163c005930f3c2266a163ffe434230bdf761" + got, err := sign(secret, payload) + if err != nil { + t.Fatalf("error signing payload: %v", err) + } + if expected != got { + t.Fatalf("invalid signature received. expected %s got %s", expected, got) + } } diff --git a/worker/worker_test.go b/worker/worker_test.go index eecd59ec..51783dda 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -1,18 +1,18 @@ package worker import ( + "testing" + "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" - "github.com/stretchr/testify/suite" ) -// WorkerSuite is a suite of tests to cover API related functions -type WorkerSuite struct { - suite.Suite +// testContext is context to cover API related functions +type testContext struct { config *config.Config } -func (s *WorkerSuite) SetupSuite() { +func setupTest(t *testing.T) *testContext { conf := &config.Config{ DBName: "sqlite3", DBPath: ":memory:", @@ -20,21 +20,15 @@ func (s *WorkerSuite) SetupSuite() { } err := models.Setup(conf) if err != nil { - s.T().Fatalf("Failed creating database: %v", err) + t.Fatalf("Failed creating database: %v", err) } - s.config = conf - s.Nil(err) + ctx := &testContext{} + ctx.config = conf + return ctx } -func (s *WorkerSuite) TearDownTest() { - campaigns, _ := models.GetCampaigns(1) - for _, campaign := range campaigns { - models.DeleteCampaign(campaign.Id) - } -} - -func (s *WorkerSuite) SetupTest() { - s.config.TestFlag = true +func createTestData(t *testing.T, ctx *testContext) { + ctx.config.TestFlag = true // Add a group group := models.Group{Name: "Test Group"} group.Targets = []models.Target{ @@ -45,12 +39,12 @@ func (s *WorkerSuite) SetupTest() { models.PostGroup(&group) // Add a template - t := models.Template{Name: "Test Template"} - t.Subject = "Test subject" - t.Text = "Text text" - t.HTML = "Test" - t.UserId = 1 - models.PostTemplate(&t) + template := models.Template{Name: "Test Template"} + template.Subject = "Test subject" + template.Text = "Text text" + template.HTML = "Test" + template.UserId = 1 + models.PostTemplate(&template) // Add a landing page p := models.Page{Name: "Test Page"} @@ -69,7 +63,7 @@ func (s *WorkerSuite) SetupTest() { // Set the status such that no emails are attempted c := models.Campaign{Name: "Test campaign"} c.UserId = 1 - c.Template = t + c.Template = template c.Page = p c.SMTP = smtp c.Groups = []models.Group{group}