diff --git a/go.mod b/go.mod index 9984fcb4..f763b49d 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect github.com/go-sql-driver/mysql v1.5.0 github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be + github.com/gorilla/context v1.1.1 github.com/gorilla/csrf v1.6.2 github.com/gorilla/handlers v1.4.2 github.com/gorilla/mux v1.7.3 @@ -23,8 +24,10 @@ require ( github.com/mxk/go-imap v0.0.0-20150429134902-531c36c3f12d github.com/oschwald/maxminddb-golang v1.6.0 github.com/sirupsen/logrus v1.4.2 - github.com/stretchr/testify v1.4.0 github.com/ziutek/mymysql v1.5.4 // indirect golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d gopkg.in/alecthomas/kingpin.v2 v2.2.6 + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 + gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect ) diff --git a/go.sum b/go.sum index 32632374..c9f467b7 100644 --- a/go.sum +++ b/go.sum @@ -13,15 +13,20 @@ github.com/andybalholm/cascadia v1.0.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9Pq github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DPaBjB8zlTR87/ElzFsnQfuHnVUVqpZZIcV5Y= github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be h1:VTe1cdyqSi/wLowKNz/shz6E0G+9/XzldZbyAmt+0Yw= github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be/go.mod h1:MpSuP7kw+gRy2z+4gIFZeF3DwhhdQhEXwRmPVQYD9ig= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/csrf v1.6.2 h1:QqQ/OWwuFp4jMKgBFAzJVW3FMULdyUW7JoM4pEWuqKg= github.com/gorilla/csrf v1.6.2/go.mod h1:7tSf8kmjNYr7IWDCYhd3U8Ck34iQ/Yw5CJu7bAkHEGI= github.com/gorilla/handlers v1.4.2 h1:0QniY0USkHQ1RGCLfKxeNHK9bkDHGRYGNDFBCS+YARg= @@ -36,11 +41,13 @@ github.com/jinzhu/gorm v1.9.12 h1:Drgk1clyWT9t9ERbzHza6Mj/8FY/CqMyVzOiHviMo6Q= github.com/jinzhu/gorm v1.9.12/go.mod h1:vhTjlKSJUTWNtcbQtrMBFCxy7eXTzeCAzfL5fBZT/Qs= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jordan-wright/email v0.0.0-20200121133829-a0b5c5b58bb6 h1:gI29NnCaNU8N7rZT2svjtas5SrbL0XsutOPtInVvGIA= github.com/jordan-wright/email v0.0.0-20200121133829-a0b5c5b58bb6/go.mod h1:1c7szIrayyPPB/987hsnvNzLushdWf4o/79s3P08L8A= github.com/jordan-wright/unindexed v0.0.0-20181209214434-78fa79113c0f h1:bYVTBvVHcAYDkH8hyVMRUW7J2mYQNNSmQPXGadYd1nY= github.com/jordan-wright/unindexed v0.0.0-20181209214434-78fa79113c0f/go.mod h1:eRt05O5haIXGKGodWjpQ2xdgBHTE7hg/pzsukNi9IRA= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kylelemons/go-gypsy v0.0.0-20160905020020-08cad365cd28 h1:mkl3tvPHIuPaWsLtmHTybJeoVEW7cbePK73Ir8VtruA= github.com/kylelemons/go-gypsy v0.0.0-20160905020020-08cad365cd28/go.mod h1:T/T7jsxVqf9k/zYOqbgNAsANsjxTd1Yq3htjDhQ1H0c= @@ -86,6 +93,11 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/models/campaign.go b/models/campaign.go index 8ab98d56..90884e14 100644 --- a/models/campaign.go +++ b/models/campaign.go @@ -155,8 +155,8 @@ func (c *Campaign) UpdateStatus(s string) error { } // AddEvent creates a new campaign event in the database -func (c *Campaign) AddEvent(e *Event) error { - e.CampaignId = c.Id +func AddEvent(e *Event, campaignID int64) error { + e.CampaignId = campaignID e.Time = time.Now().UTC() whs, err := GetActiveWebhooks() @@ -362,6 +362,38 @@ func GetCampaignSummary(id int64, uid int64) (CampaignSummary, error) { return cs, nil } +// GetCampaignMailContext returns a campaign object with just the relevant +// data needed to generate and send emails. This includes the top-level +// metadata, the template, and the sending profile. +// +// This should only ever be used if you specifically want this lightweight +// context, since it returns a non-standard campaign object. +// ref: #1726 +func GetCampaignMailContext(id int64, uid int64) (Campaign, error) { + c := Campaign{} + err := db.Where("id = ?", id).Where("user_id = ?", uid).Find(&c).Error + if err != nil { + return c, err + } + err = db.Table("smtp").Where("id=?", c.SMTPId).Find(&c.SMTP).Error + if err != nil { + return c, err + } + err = db.Where("smtp_id=?", c.SMTP.Id).Find(&c.SMTP.Headers).Error + if err != nil && err != gorm.ErrRecordNotFound { + return c, err + } + err = db.Table("templates").Where("id=?", c.TemplateId).Find(&c.Template).Error + if err != nil { + return c, err + } + err = db.Where("template_id=?", c.Template.Id).Find(&c.Template.Attachments).Error + if err != nil && err != gorm.ErrRecordNotFound { + return c, err + } + return c, nil +} + // GetCampaign returns the campaign, if it exists, specified by the given id and user_id. func GetCampaign(id int64, uid int64) (Campaign, error) { c := Campaign{} @@ -500,7 +532,7 @@ func PostCampaign(c *Campaign, uid int64) error { log.Error(err) return err } - err = c.AddEvent(&Event{Message: "Campaign Created"}) + err = AddEvent(&Event{Message: "Campaign Created"}, c.Id) if err != nil { log.Error(err) } diff --git a/models/campaign_test.go b/models/campaign_test.go index e1ca9b9f..6491987e 100644 --- a/models/campaign_test.go +++ b/models/campaign_test.go @@ -283,3 +283,55 @@ func BenchmarkCampaign10000(b *testing.B) { } tearDownBenchmark(b) } + +func BenchmarkGetCampaign100(b *testing.B) { + setupBenchmark(b) + campaign := setupCampaign(b, 100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := GetCampaign(campaign.Id, campaign.UserId) + if err != nil { + b.Fatalf("error getting campaign: %v", err) + } + } + tearDownBenchmark(b) +} + +func BenchmarkGetCampaign1000(b *testing.B) { + setupBenchmark(b) + campaign := setupCampaign(b, 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := GetCampaign(campaign.Id, campaign.UserId) + if err != nil { + b.Fatalf("error getting campaign: %v", err) + } + } + tearDownBenchmark(b) +} + +func BenchmarkGetCampaign5000(b *testing.B) { + setupBenchmark(b) + campaign := setupCampaign(b, 5000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := GetCampaign(campaign.Id, campaign.UserId) + if err != nil { + b.Fatalf("error getting campaign: %v", err) + } + } + tearDownBenchmark(b) +} + +func BenchmarkGetCampaign10000(b *testing.B) { + setupBenchmark(b) + campaign := setupCampaign(b, 10000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := GetCampaign(campaign.Id, campaign.UserId) + if err != nil { + b.Fatalf("error getting campaign: %v", err) + } + } + tearDownBenchmark(b) +} diff --git a/models/maillog.go b/models/maillog.go index 59a7f639..082ddceb 100644 --- a/models/maillog.go +++ b/models/maillog.go @@ -37,6 +37,8 @@ type MailLog struct { SendDate time.Time `json:"send_date"` SendAttempt int `json:"send_attempt"` Processing bool `json:"-"` + + cachedCampaign *Campaign } // GenerateMailLog creates a new maillog for the given campaign and @@ -128,13 +130,27 @@ func (m *MailLog) Success() error { // GetDialer returns a dialer based on the maillog campaign's SMTP configuration func (m *MailLog) GetDialer() (mailer.Dialer, error) { - c, err := GetCampaign(m.CampaignId, m.UserId) - if err != nil { - return nil, err + c := m.cachedCampaign + if c == nil { + campaign, err := GetCampaignMailContext(m.CampaignId, m.UserId) + if err != nil { + return nil, err + } + c = &campaign } return c.SMTP.GetDialer() } +// CacheCampaign allows bulk-mail workers to cache the otherwise expensive +// campaign lookup operation by providing a pointer to the campaign here. +func (m *MailLog) CacheCampaign(campaign *Campaign) error { + if campaign.Id != m.CampaignId { + return fmt.Errorf("incorrect campaign provided for caching. expected %d got %d", m.CampaignId, campaign.Id) + } + m.cachedCampaign = campaign + return nil +} + // Generate fills in the details of a gomail.Message instance with // the correct headers and body from the campaign and recipient listed in // the maillog. We accept the gomail.Message as an argument so that the caller @@ -144,9 +160,13 @@ func (m *MailLog) Generate(msg *gomail.Message) error { if err != nil { return err } - c, err := GetCampaign(m.CampaignId, m.UserId) - if err != nil { - return err + c := m.cachedCampaign + if c == nil { + campaign, err := GetCampaignMailContext(m.CampaignId, m.UserId) + if err != nil { + return err + } + c = &campaign } f, err := mail.ParseAddress(c.SMTP.FromAddress) @@ -155,7 +175,7 @@ func (m *MailLog) Generate(msg *gomail.Message) error { } msg.SetAddressHeader("From", f.Address, f.Name) - ptx, err := NewPhishingTemplateContext(&c, r.BaseRecipient, r.RId) + ptx, err := NewPhishingTemplateContext(c, r.BaseRecipient, r.RId) if err != nil { return err } diff --git a/models/maillog_test.go b/models/maillog_test.go index f90b8512..0c8c757c 100644 --- a/models/maillog_test.go +++ b/models/maillog_test.go @@ -331,6 +331,7 @@ func BenchmarkMailLogGenerate100(b *testing.B) { if err != nil { b.Fatalf("error getting maillogs for campaign: %v", err) } + ms[0].CacheCampaign(&campaign) b.ResetTimer() for i := 0; i < b.N; i++ { msg := gomail.NewMessage() @@ -346,6 +347,7 @@ func BenchmarkMailLogGenerate1000(b *testing.B) { if err != nil { b.Fatalf("error getting maillogs for campaign: %v", err) } + ms[0].CacheCampaign(&campaign) b.ResetTimer() for i := 0; i < b.N; i++ { msg := gomail.NewMessage() @@ -361,6 +363,7 @@ func BenchmarkMailLogGenerate5000(b *testing.B) { if err != nil { b.Fatalf("error getting maillogs for campaign: %v", err) } + ms[0].CacheCampaign(&campaign) b.ResetTimer() for i := 0; i < b.N; i++ { msg := gomail.NewMessage() @@ -376,6 +379,7 @@ func BenchmarkMailLogGenerate10000(b *testing.B) { if err != nil { b.Fatalf("error getting maillogs for campaign: %v", err) } + ms[0].CacheCampaign(&campaign) b.ResetTimer() for i := 0; i < b.N; i++ { msg := gomail.NewMessage() diff --git a/models/result.go b/models/result.go index 7b071ea4..6ad5812f 100644 --- a/models/result.go +++ b/models/result.go @@ -39,10 +39,6 @@ type Result struct { } func (r *Result) createEvent(status string, details interface{}) (*Event, error) { - c, err := GetCampaign(r.CampaignId, r.UserId) - if err != nil { - return nil, err - } e := &Event{Email: r.Email, Message: status} if details != nil { dj, err := json.Marshal(details) @@ -51,7 +47,7 @@ func (r *Result) createEvent(status string, details interface{}) (*Event, error) } e.Details = string(dj) } - c.AddEvent(e) + AddEvent(e, r.CampaignId) return e, nil } diff --git a/worker/worker.go b/worker/worker.go index 1a93128e..c4f2f1f0 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -45,55 +45,71 @@ func WithMailer(m mailer.Mailer) func(*DefaultWorker) error { } } +// processCampaigns loads maillogs scheduled to be sent before the provided +// time and sends them to the mailer. +func (w *DefaultWorker) processCampaigns(t time.Time) error { + ms, err := models.GetQueuedMailLogs(t.UTC()) + if err != nil { + log.Error(err) + return err + } + // Lock the MailLogs (they will be unlocked after processing) + err = models.LockMailLogs(ms, true) + if err != nil { + return err + } + campaignCache := make(map[int64]models.Campaign) + // We'll group the maillogs by campaign ID to (roughly) group + // them by sending profile. This lets the mailer re-use the Sender + // instead of having to re-connect to the SMTP server for every + // email. + msg := make(map[int64][]mailer.Mail) + for _, m := range ms { + // We cache the campaign here to greatly reduce the time it takes to + // generate the message (ref #1726) + c, ok := campaignCache[m.CampaignId] + if !ok { + c, err = models.GetCampaignMailContext(m.CampaignId, m.UserId) + if err != nil { + return err + } + campaignCache[c.Id] = c + } + m.CacheCampaign(&c) + msg[m.CampaignId] = append(msg[m.CampaignId], m) + } + + // Next, we process each group of maillogs in parallel + for cid, msc := range msg { + go func(cid int64, msc []mailer.Mail) { + c := campaignCache[cid] + if c.Status == models.CampaignQueued { + err := c.UpdateStatus(models.CampaignInProgress) + if err != nil { + log.Error(err) + return + } + } + log.WithFields(logrus.Fields{ + "num_emails": len(msc), + }).Info("Sending emails to mailer for processing") + w.mailer.Queue(msc) + }(cid, msc) + } + return nil +} + // Start launches the worker to poll the database every minute for any pending maillogs // that need to be processed. func (w *DefaultWorker) Start() { log.Info("Background Worker Started Successfully - Waiting for Campaigns") go w.mailer.Start(context.Background()) for t := range time.Tick(1 * time.Minute) { - ms, err := models.GetQueuedMailLogs(t.UTC()) + err := w.processCampaigns(t) if err != nil { log.Error(err) continue } - // Lock the MailLogs (they will be unlocked after processing) - err = models.LockMailLogs(ms, true) - if err != nil { - log.Error(err) - continue - } - // We'll group the maillogs by campaign ID to (sort of) group - // them by sending profile. This lets the mailer re-use the Sender - // instead of having to re-connect to the SMTP server for every - // email. - msg := make(map[int64][]mailer.Mail) - for _, m := range ms { - msg[m.CampaignId] = append(msg[m.CampaignId], m) - } - - // Next, we process each group of maillogs in parallel - for cid, msc := range msg { - go func(cid int64, msc []mailer.Mail) { - uid := msc[0].(*models.MailLog).UserId - c, err := models.GetCampaign(cid, uid) - if err != nil { - log.Error(err) - errorMail(err, msc) - return - } - if c.Status == models.CampaignQueued { - err := c.UpdateStatus(models.CampaignInProgress) - if err != nil { - log.Error(err) - return - } - } - log.WithFields(logrus.Fields{ - "num_emails": len(msc), - }).Info("Sending emails to mailer for processing") - w.mailer.Queue(msc) - }(cid, msc) - } } } @@ -116,6 +132,11 @@ func (w *DefaultWorker) LaunchCampaign(c models.Campaign) { m.Unlock() continue } + err = m.CacheCampaign(&c) + if err != nil { + log.Error(err) + return + } mailEntries = append(mailEntries, m) } w.mailer.Queue(mailEntries) diff --git a/worker/worker_test.go b/worker/worker_test.go index 51783dda..44c6eea1 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -1,12 +1,28 @@ package worker import ( + "context" + "fmt" "testing" + "time" "github.com/gophish/gophish/config" + "github.com/gophish/gophish/mailer" "github.com/gophish/gophish/models" ) +type logMailer struct { + queue chan []mailer.Mail +} + +func (m *logMailer) Start(ctx context.Context) { + return +} + +func (m *logMailer) Queue(ms []mailer.Mail) { + m.queue <- ms +} + // testContext is context to cover API related functions type testContext struct { config *config.Config @@ -24,6 +40,7 @@ func setupTest(t *testing.T) *testContext { } ctx := &testContext{} ctx.config = conf + createTestData(t, ctx) return ctx } @@ -31,9 +48,12 @@ func createTestData(t *testing.T, ctx *testContext) { ctx.config.TestFlag = true // Add a group group := models.Group{Name: "Test Group"} - group.Targets = []models.Target{ - models.Target{BaseRecipient: models.BaseRecipient{Email: "test1@example.com", FirstName: "First", LastName: "Example"}}, - models.Target{BaseRecipient: models.BaseRecipient{Email: "test2@example.com", FirstName: "Second", LastName: "Example"}}, + for i := 0; i < 10; i++ { + group.Targets = append(group.Targets, models.Target{ + BaseRecipient: models.BaseRecipient{ + Email: fmt.Sprintf("test%d@example.com", i), + FirstName: "First", + LastName: "Example"}}) } group.UserId = 1 models.PostGroup(&group) @@ -58,15 +78,88 @@ func createTestData(t *testing.T, ctx *testContext) { smtp.Host = "example.com" smtp.FromAddress = "test@test.com" models.PostSMTP(&smtp) +} +func setupCampaign(id int) (*models.Campaign, error) { // Setup and "launch" our campaign // Set the status such that no emails are attempted - c := models.Campaign{Name: "Test campaign"} + c := models.Campaign{Name: fmt.Sprintf("Test campaign - %d", id)} c.UserId = 1 + template, err := models.GetTemplate(1, 1) + if err != nil { + return nil, err + } c.Template = template - c.Page = p + + page, err := models.GetPage(1, 1) + if err != nil { + return nil, err + } + c.Page = page + + smtp, err := models.GetSMTP(1, 1) + if err != nil { + return nil, err + } c.SMTP = smtp + + group, err := models.GetGroup(1, 1) + if err != nil { + return nil, err + } c.Groups = []models.Group{group} - models.PostCampaign(&c, c.UserId) - c.UpdateStatus(models.CampaignEmailsSent) + err = models.PostCampaign(&c, c.UserId) + if err != nil { + return nil, err + } + err = c.UpdateStatus(models.CampaignEmailsSent) + return &c, err +} + +func TestMailLogGrouping(t *testing.T) { + setupTest(t) + + // Create the campaigns and unlock the maillogs so that they're picked up + // by the worker + for i := 0; i < 10; i++ { + campaign, err := setupCampaign(i) + if err != nil { + t.Fatalf("error creating campaign: %v", err) + } + ms, err := models.GetMailLogsByCampaign(campaign.Id) + if err != nil { + t.Fatalf("error getting maillogs for campaign: %v", err) + } + for _, m := range ms { + m.Unlock() + } + } + + lm := &logMailer{queue: make(chan []mailer.Mail)} + worker := &DefaultWorker{} + worker.mailer = lm + + // Trigger the worker, generating the maillogs and sending them to the + // mailer + worker.processCampaigns(time.Now()) + + // Verify that each slice of maillogs received belong to the same campaign + for i := 0; i < 10; i++ { + ms := <-lm.queue + maillog, ok := ms[0].(*models.MailLog) + if !ok { + t.Fatalf("unable to cast mail to models.MailLog") + } + expected := maillog.CampaignId + for _, m := range ms { + maillog, ok = m.(*models.MailLog) + if !ok { + t.Fatalf("unable to cast mail to models.MailLog") + } + got := maillog.CampaignId + if got != expected { + t.Fatalf("unexpected campaign ID received for maillog: got %d expected %d", got, expected) + } + } + } }