From 47f0049c30c0ec1e581e82934450287842d910bc Mon Sep 17 00:00:00 2001 From: Jordan Wright Date: Sat, 15 Dec 2018 15:42:32 -0600 Subject: [PATCH] Refactor servers (#1321) * Refactoring servers to support custom workers and graceful shutdown. * Refactoring workers to support custom mailers. * Refactoring mailer to be an interface, with proper instances instead of a single global instance * Cleaning up a few things. Locking maillogs for campaigns set to launch immediately to prevent a race condition. * Cleaning up API middleware to be simpler * Moving template parameters to separate struct * Changed LoadConfig to return config object * Cleaned up some error handling, removing uninitialized global error in models package * Changed static file serving to use the unindexed package --- config/config.go | 18 +-- config/config_test.go | 8 +- controllers/api.go | 55 +++---- controllers/api_test.go | 49 +++--- controllers/phish.go | 124 ++++++++++++--- controllers/phish_test.go | 24 +-- controllers/route.go | 299 +++++++++++++++++++---------------- controllers/route_test.go | 14 +- controllers/static.go | 35 ---- controllers/static_test.go | 81 ---------- gophish.go | 81 ++++------ logger/logger.go | 4 +- mailer/mailer.go | 27 ++-- mailer/mailer_test.go | 8 +- middleware/middleware.go | 14 +- models/campaign.go | 22 ++- models/campaign_test.go | 26 +++ models/email_request.go | 4 +- models/email_request_test.go | 16 +- models/group.go | 10 +- models/maillog.go | 10 +- models/maillog_test.go | 18 +-- models/models.go | 12 +- models/models_test.go | 15 +- models/page.go | 2 +- models/smtp.go | 2 +- models/smtp_test.go | 8 +- models/template.go | 10 +- util/util_test.go | 10 +- worker/worker.go | 49 ++++-- worker/worker_test.go | 19 ++- 31 files changed, 554 insertions(+), 520 deletions(-) delete mode 100644 controllers/static.go delete mode 100644 controllers/static_test.go diff --git a/config/config.go b/config/config.go index c1589837..ba81e2f1 100644 --- a/config/config.go +++ b/config/config.go @@ -38,9 +38,6 @@ type Config struct { Logging LoggingConfig `json:"logging"` } -// Conf contains the initialized configuration struct -var Conf Config - // Version contains the current gophish version var Version = "" @@ -48,19 +45,20 @@ var Version = "" const ServerName = "gophish" // LoadConfig loads the configuration from the specified filepath -func LoadConfig(filepath string) error { +func LoadConfig(filepath string) (*Config, error) { // Get the config file configFile, err := ioutil.ReadFile(filepath) if err != nil { - return err + return nil, err } - err = json.Unmarshal(configFile, &Conf) + config := &Config{} + err = json.Unmarshal(configFile, config) if err != nil { - return err + return nil, err } // Choosing the migrations directory based on the database used. - Conf.MigrationsPath = Conf.MigrationsPath + Conf.DBName + config.MigrationsPath = config.MigrationsPath + config.DBName // Explicitly set the TestFlag to false to prevent config.json overrides - Conf.TestFlag = false - return nil + config.TestFlag = false + return config, nil } diff --git a/config/config_test.go b/config/config_test.go index 7e16b81a..26f4fd4f 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -48,18 +48,18 @@ func (s *ConfigSuite) TestLoadConfig() { _, err := s.ConfigFile.Write(validConfig) s.Nil(err) // Load the valid config - err = LoadConfig(s.ConfigFile.Name()) + conf, err := LoadConfig(s.ConfigFile.Name()) s.Nil(err) - expectedConfig := Config{} + expectedConfig := &Config{} err = json.Unmarshal(validConfig, &expectedConfig) s.Nil(err) expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName expectedConfig.TestFlag = false - s.Equal(expectedConfig, Conf) + s.Equal(expectedConfig, conf) // Load an invalid config - err = LoadConfig("bogusfile") + conf, err = LoadConfig("bogusfile") s.NotNil(err) } diff --git a/controllers/api.go b/controllers/api.go index 52abfce0..bd6d3938 100644 --- a/controllers/api.go +++ b/controllers/api.go @@ -17,23 +17,14 @@ import ( log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" "github.com/gophish/gophish/util" - "github.com/gophish/gophish/worker" "github.com/gorilla/mux" "github.com/jinzhu/gorm" "github.com/jordan-wright/email" "github.com/sirupsen/logrus" ) -// Worker is the worker that processes phishing events and updates campaigns. -var Worker *worker.Worker - -func init() { - Worker = worker.New() - go Worker.Start() -} - // API (/api/reset) resets a user's API key -func API_Reset(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Reset(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "POST": u := ctx.Get(r, "user").(models.User) @@ -49,7 +40,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) { // API_Campaigns returns a list of campaigns if requested via GET. // If requested via POST, API_Campaigns creates a new campaign and returns a reference to it. -func API_Campaigns(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaigns(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": cs, err := models.GetCampaigns(ctx.Get(r, "user_id").(int64)) @@ -74,14 +65,14 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) { // If the campaign is scheduled to launch immediately, send it to the worker. // Otherwise, the worker will pick it up at the scheduled time if c.Status == models.CAMPAIGN_IN_PROGRESS { - go Worker.LaunchCampaign(c) + go as.worker.LaunchCampaign(c) } JSONResponse(w, c, http.StatusCreated) } } // API_Campaigns_Summary returns the summary for the current user's campaigns -func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": cs, err := models.GetCampaignSummaries(ctx.Get(r, "user_id").(int64)) @@ -96,7 +87,7 @@ func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) { // API_Campaigns_Id returns details about the requested campaign. If the campaign is not // valid, API_Campaigns_Id returns null. -func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaigns_Id(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) c, err := models.GetCampaign(id, ctx.Get(r, "user_id").(int64)) @@ -120,7 +111,7 @@ func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) { // API_Campaigns_Id_Results returns just the results for a given campaign to // significantly reduce the information returned. -func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) cr, err := models.GetCampaignResults(id, ctx.Get(r, "user_id").(int64)) @@ -136,7 +127,7 @@ func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) { } // API_Campaigns_Id_Summary returns just the summary for a given campaign. -func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) switch { @@ -157,7 +148,7 @@ func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) { // API_Campaigns_Id_Complete effectively "ends" a campaign. // Future phishing emails clicked will return a simple "404" page. -func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) switch { @@ -173,7 +164,7 @@ func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) { // API_Groups returns a list of groups if requested via GET. // If requested via POST, API_Groups creates a new group and returns a reference to it. -func API_Groups(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Groups(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": gs, err := models.GetGroups(ctx.Get(r, "user_id").(int64)) @@ -208,7 +199,7 @@ func API_Groups(w http.ResponseWriter, r *http.Request) { } // API_Groups_Summary returns a summary of the groups owned by the current user. -func API_Groups_Summary(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Groups_Summary(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": gs, err := models.GetGroupSummaries(ctx.Get(r, "user_id").(int64)) @@ -223,7 +214,7 @@ func API_Groups_Summary(w http.ResponseWriter, r *http.Request) { // API_Groups_Id returns details about the requested group. // If the group is not valid, API_Groups_Id returns null. -func API_Groups_Id(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Groups_Id(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) @@ -261,7 +252,7 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) { } // API_Groups_Id_Summary returns a summary of the groups owned by the current user. -func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": vars := mux.Vars(r) @@ -276,7 +267,7 @@ func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) { } // API_Templates handles the functionality for the /api/templates endpoint -func API_Templates(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Templates(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": ts, err := models.GetTemplates(ctx.Get(r, "user_id").(int64)) @@ -319,7 +310,7 @@ func API_Templates(w http.ResponseWriter, r *http.Request) { } // API_Templates_Id handles the functions for the /api/templates/:id endpoint -func API_Templates_Id(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Templates_Id(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) t, err := models.GetTemplate(id, ctx.Get(r, "user_id").(int64)) @@ -359,7 +350,7 @@ func API_Templates_Id(w http.ResponseWriter, r *http.Request) { } // API_Pages handles requests for the /api/pages/ endpoint -func API_Pages(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Pages(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": ps, err := models.GetPages(ctx.Get(r, "user_id").(int64)) @@ -396,7 +387,7 @@ func API_Pages(w http.ResponseWriter, r *http.Request) { // API_Pages_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing // of a Page object -func API_Pages_Id(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Pages_Id(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) p, err := models.GetPage(id, ctx.Get(r, "user_id").(int64)) @@ -436,7 +427,7 @@ func API_Pages_Id(w http.ResponseWriter, r *http.Request) { } // API_SMTP handles requests for the /api/smtp/ endpoint -func API_SMTP(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_SMTP(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": ss, err := models.GetSMTPs(ctx.Get(r, "user_id").(int64)) @@ -473,7 +464,7 @@ func API_SMTP(w http.ResponseWriter, r *http.Request) { // API_SMTP_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing // of a SMTP object -func API_SMTP_Id(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_SMTP_Id(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id, _ := strconv.ParseInt(vars["id"], 0, 64) s, err := models.GetSMTP(id, ctx.Get(r, "user_id").(int64)) @@ -518,7 +509,7 @@ func API_SMTP_Id(w http.ResponseWriter, r *http.Request) { } // API_Import_Group imports a CSV of group members -func API_Import_Group(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Import_Group(w http.ResponseWriter, r *http.Request) { ts, err := util.ParseCSV(r) if err != nil { JSONResponse(w, models.Response{Success: false, Message: "Error parsing CSV"}, http.StatusInternalServerError) @@ -530,7 +521,7 @@ func API_Import_Group(w http.ResponseWriter, r *http.Request) { // API_Import_Email allows for the importing of email. // Returns a Message object -func API_Import_Email(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Import_Email(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest) return @@ -579,7 +570,7 @@ func API_Import_Email(w http.ResponseWriter, r *http.Request) { // API_Import_Site allows for the importing of HTML from a website // Without "include_resources" set, it will merely place a "base" tag // so that all resources can be loaded relative to the given URL. -func API_Import_Site(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Import_Site(w http.ResponseWriter, r *http.Request) { cr := cloneRequest{} if r.Method != "POST" { JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest) @@ -637,7 +628,7 @@ func API_Import_Site(w http.ResponseWriter, r *http.Request) { // API_Send_Test_Email sends a test email using the template name // and Target given. -func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) API_Send_Test_Email(w http.ResponseWriter, r *http.Request) { s := &models.EmailRequest{ ErrorChan: make(chan error), UserId: ctx.Get(r, "user_id").(int64), @@ -735,7 +726,7 @@ func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) { } } // Send the test email - err = Worker.SendTestEmail(s) + err = as.worker.SendTestEmail(s) if err != nil { log.Error(err) JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError) diff --git a/controllers/api_test.go b/controllers/api_test.go index 8e024ece..cc0684c7 100644 --- a/controllers/api_test.go +++ b/controllers/api_test.go @@ -11,41 +11,42 @@ import ( "github.com/gophish/gophish/config" "github.com/gophish/gophish/models" - "github.com/gorilla/handlers" "github.com/stretchr/testify/suite" ) // ControllersSuite is a suite of tests to cover API related functions type ControllersSuite struct { suite.Suite - ApiKey string + ApiKey string + config *config.Config + adminServer *httptest.Server + phishServer *httptest.Server } -// as is the Admin Server for our API calls -var as *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreateAdminRouter())) - -// ps is the Phishing Server -var ps *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreatePhishingRouter())) - func (s *ControllersSuite) SetupSuite() { - config.Conf.DBName = "sqlite3" - config.Conf.DBPath = ":memory:" - config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" - err := models.Setup() + conf := &config.Config{ + DBName: "sqlite3", + DBPath: ":memory:", + MigrationsPath: "../db/db_sqlite3/migrations/", + } + err := models.Setup(conf) if err != nil { s.T().Fatalf("Failed creating database: %v", err) } + s.config = conf s.Nil(err) // Setup the admin server for use in testing - as.Config.Addr = config.Conf.AdminConf.ListenURL - as.Start() + s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler) + s.adminServer.Config.Addr = s.config.AdminConf.ListenURL + s.adminServer.Start() // Get the API key to use for these tests u, err := models.GetUser(1) s.Nil(err) s.ApiKey = u.ApiKey // Start the phishing server - ps.Config.Addr = config.Conf.PhishConf.ListenURL - ps.Start() + s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler) + s.phishServer.Config.Addr = s.config.PhishConf.ListenURL + s.phishServer.Start() // Move our cwd up to the project root for help with resolving // static assets err = os.Chdir("../") @@ -103,21 +104,21 @@ func (s *ControllersSuite) SetupTest() { } func (s *ControllersSuite) TestRequireAPIKey() { - resp, err := http.Post(fmt.Sprintf("%s/api/import/site", as.URL), "application/json", nil) + resp, err := http.Post(fmt.Sprintf("%s/api/import/site", s.adminServer.URL), "application/json", nil) s.Nil(err) defer resp.Body.Close() - s.Equal(resp.StatusCode, http.StatusBadRequest) + s.Equal(resp.StatusCode, http.StatusUnauthorized) } func (s *ControllersSuite) TestInvalidAPIKey() { - resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", as.URL, "bogus-api-key")) + resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", s.adminServer.URL, "bogus-api-key")) s.Nil(err) defer resp.Body.Close() - s.Equal(resp.StatusCode, http.StatusBadRequest) + s.Equal(resp.StatusCode, http.StatusUnauthorized) } func (s *ControllersSuite) TestBearerToken() { - req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", as.URL), nil) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", s.adminServer.URL), nil) s.Nil(err) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey)) resp, err := http.DefaultClient.Do(req) @@ -133,7 +134,7 @@ func (s *ControllersSuite) TestSiteImportBaseHref() { })) hr := fmt.Sprintf("\n", ts.URL) defer ts.Close() - resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", as.URL, s.ApiKey), "application/json", + resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", s.adminServer.URL, s.ApiKey), "application/json", bytes.NewBuffer([]byte(fmt.Sprintf(` { "url" : "%s", @@ -150,8 +151,8 @@ func (s *ControllersSuite) TestSiteImportBaseHref() { func (s *ControllersSuite) TearDownSuite() { // Tear down the admin and phishing servers - as.Close() - ps.Close() + s.adminServer.Close() + s.phishServer.Close() } func TestControllerSuite(t *testing.T) { diff --git a/controllers/phish.go b/controllers/phish.go index 04165829..e1996d3d 100644 --- a/controllers/phish.go +++ b/controllers/phish.go @@ -1,6 +1,8 @@ package controllers import ( + "compress/gzip" + "context" "errors" "fmt" "net" @@ -8,11 +10,15 @@ import ( "strings" "time" + "github.com/NYTimes/gziphandler" "github.com/gophish/gophish/config" ctx "github.com/gophish/gophish/context" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" + "github.com/gophish/gophish/util" + "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/jordan-wright/unindexed" ) // ErrInvalidRequest is thrown when a request with an invalid structure is @@ -35,22 +41,91 @@ type TransparencyResponse struct { // to return a transparency response. const TransparencySuffix = "+" -// CreatePhishingRouter creates the router that handles phishing connections. -func CreatePhishingRouter() http.Handler { - router := mux.NewRouter() - fileServer := http.FileServer(UnindexedFileSystem{http.Dir("./static/endpoint/")}) - router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer)) - router.HandleFunc("/track", PhishTracker) - router.HandleFunc("/robots.txt", RobotsHandler) - router.HandleFunc("/{path:.*}/track", PhishTracker) - router.HandleFunc("/{path:.*}/report", PhishReporter) - router.HandleFunc("/report", PhishReporter) - router.HandleFunc("/{path:.*}", PhishHandler) - return router +// PhishingServerOption is a functional option that is used to configure the +// the phishing server +type PhishingServerOption func(*PhishingServer) + +// PhishingServer is an HTTP server that implements the campaign event +// handlers, such as email open tracking, click tracking, and more. +type PhishingServer struct { + server *http.Server + config config.PhishServer + contactAddress string } -// PhishTracker tracks emails as they are opened, updating the status for the given Result -func PhishTracker(w http.ResponseWriter, r *http.Request) { +// NewPhishingServer returns a new instance of the phishing server with +// provided options applied. +func NewPhishingServer(config config.PhishServer, options ...PhishingServerOption) *PhishingServer { + defaultServer := &http.Server{ + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + Addr: config.ListenURL, + } + ps := &PhishingServer{ + server: defaultServer, + config: config, + } + for _, opt := range options { + opt(ps) + } + ps.registerRoutes() + return ps +} + +// WithContactAddress sets the contact address used by the transparency +// handlers +func WithContactAddress(addr string) PhishingServerOption { + return func(ps *PhishingServer) { + ps.contactAddress = addr + } +} + +// Start launches the phishing server, listening on the configured address. +func (ps *PhishingServer) Start() error { + if ps.config.UseTLS { + err := util.CheckAndCreateSSL(ps.config.CertPath, ps.config.KeyPath) + if err != nil { + log.Fatal(err) + return err + } + log.Infof("Starting phishing server at https://%s", ps.config.ListenURL) + return ps.server.ListenAndServeTLS(ps.config.CertPath, ps.config.KeyPath) + } + // If TLS isn't configured, just listen on HTTP + log.Infof("Starting phishing server at http://%s", ps.config.ListenURL) + return ps.server.ListenAndServe() +} + +// Shutdown attempts to gracefully shutdown the server. +func (ps *PhishingServer) Shutdown() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return ps.server.Shutdown(ctx) +} + +// CreatePhishingRouter creates the router that handles phishing connections. +func (ps *PhishingServer) registerRoutes() { + router := mux.NewRouter() + fileServer := http.FileServer(unindexed.Dir("./static/endpoint/")) + router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer)) + router.HandleFunc("/track", ps.TrackHandler) + router.HandleFunc("/robots.txt", ps.RobotsHandler) + router.HandleFunc("/{path:.*}/track", ps.TrackHandler) + router.HandleFunc("/{path:.*}/report", ps.ReportHandler) + router.HandleFunc("/report", ps.ReportHandler) + router.HandleFunc("/{path:.*}", ps.PhishHandler) + + // Setup GZIP compression + gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression) + phishHandler := gzipWrapper(router) + + // Setup logging + phishHandler = handlers.CombinedLoggingHandler(log.Writer(), phishHandler) + ps.server.Handler = router +} + +// TrackHandler tracks emails as they are opened, updating the status for the given Result +func (ps *PhishingServer) TrackHandler(w http.ResponseWriter, r *http.Request) { err, r := setupContext(r) if err != nil { // Log the error if it wasn't something we can safely ignore @@ -71,7 +146,7 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) { // Check for a transparency request if strings.HasSuffix(rid, TransparencySuffix) { - TransparencyHandler(w, r) + ps.TransparencyHandler(w, r) return } @@ -82,8 +157,8 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "static/images/pixel.png") } -// PhishReporter tracks emails as they are reported, updating the status for the given Result -func PhishReporter(w http.ResponseWriter, r *http.Request) { +// ReportHandler tracks emails as they are reported, updating the status for the given Result +func (ps *PhishingServer) ReportHandler(w http.ResponseWriter, r *http.Request) { err, r := setupContext(r) if err != nil { // Log the error if it wasn't something we can safely ignore @@ -104,7 +179,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) { // Check for a transparency request if strings.HasSuffix(rid, TransparencySuffix) { - TransparencyHandler(w, r) + ps.TransparencyHandler(w, r) return } @@ -117,7 +192,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) { // PhishHandler handles incoming client connections and registers the associated actions performed // (such as clicked link, etc.) -func PhishHandler(w http.ResponseWriter, r *http.Request) { +func (ps *PhishingServer) PhishHandler(w http.ResponseWriter, r *http.Request) { err, r := setupContext(r) if err != nil { // Log the error if it wasn't something we can safely ignore @@ -152,7 +227,7 @@ func PhishHandler(w http.ResponseWriter, r *http.Request) { // Check for a transparency request if strings.HasSuffix(rid, TransparencySuffix) { - TransparencyHandler(w, r) + ps.TransparencyHandler(w, r) return } @@ -211,23 +286,24 @@ func renderPhishResponse(w http.ResponseWriter, r *http.Request, ptx models.Phis } // RobotsHandler prevents search engines, etc. from indexing phishing materials -func RobotsHandler(w http.ResponseWriter, r *http.Request) { +func (ps *PhishingServer) RobotsHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "User-agent: *\nDisallow: /") } // TransparencyHandler returns a TransparencyResponse for the provided result // and campaign. -func TransparencyHandler(w http.ResponseWriter, r *http.Request) { +func (ps *PhishingServer) TransparencyHandler(w http.ResponseWriter, r *http.Request) { rs := ctx.Get(r, "result").(models.Result) tr := &TransparencyResponse{ Server: config.ServerName, SendDate: rs.SendDate, - ContactAddress: config.Conf.ContactAddress, + ContactAddress: ps.contactAddress, } JSONResponse(w, tr, http.StatusOK) } -// setupContext handles some of the administrative work around receiving a new request, such as checking the result ID, the campaign, etc. +// setupContext handles some of the administrative work around receiving a new +// request, such as checking the result ID, the campaign, etc. func setupContext(r *http.Request) (error, *http.Request) { err := r.ParseForm() if err != nil { diff --git a/controllers/phish_test.go b/controllers/phish_test.go index 59425c8a..dca961a6 100644 --- a/controllers/phish_test.go +++ b/controllers/phish_test.go @@ -38,7 +38,7 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest { } func (s *ControllersSuite) openEmail(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) @@ -49,19 +49,19 @@ func (s *ControllersSuite) openEmail(rid string) { } func (s *ControllersSuite) reportedEmail(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) s.Equal(resp.StatusCode, http.StatusNoContent) } func (s *ControllersSuite) reportEmail404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) s.Equal(resp.StatusCode, http.StatusNotFound) } func (s *ControllersSuite) openEmail404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) defer resp.Body.Close() s.Nil(err) @@ -69,7 +69,7 @@ func (s *ControllersSuite) openEmail404(rid string) { } func (s *ControllersSuite) clickLink(rid string, expectedHTML string) { - resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) @@ -79,7 +79,7 @@ func (s *ControllersSuite) clickLink(rid string, expectedHTML string) { } func (s *ControllersSuite) clickLink404(rid string) { - resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid)) s.Nil(err) defer resp.Body.Close() s.Nil(err) @@ -87,14 +87,14 @@ func (s *ControllersSuite) clickLink404(rid string) { } func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) { - resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ps.URL, path, models.RecipientParameter, rid)) + resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid)) s.Nil(err) defer resp.Body.Close() s.Equal(resp.StatusCode, http.StatusOK) tr := &TransparencyResponse{} err = json.NewDecoder(resp.Body).Decode(tr) s.Nil(err) - s.Equal(tr.ContactAddress, config.Conf.ContactAddress) + s.Equal(tr.ContactAddress, s.config.ContactAddress) s.Equal(tr.SendDate, r.SendDate) s.Equal(tr.Server, config.ServerName) } @@ -146,11 +146,11 @@ func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() { } func (s *ControllersSuite) TestNoRecipientID() { - resp, err := http.Get(fmt.Sprintf("%s/track", ps.URL)) + resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL)) s.Nil(err) s.Equal(resp.StatusCode, http.StatusNotFound) - resp, err = http.Get(ps.URL) + resp, err = http.Get(s.phishServer.URL) s.Nil(err) s.Equal(resp.StatusCode, http.StatusNotFound) } @@ -183,7 +183,7 @@ func (s *ControllersSuite) TestCompletedCampaignClick() { func (s *ControllersSuite) TestRobotsHandler() { expected := []byte("User-agent: *\nDisallow: /\n") - resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ps.URL)) + resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL)) s.Nil(err) s.Equal(resp.StatusCode, http.StatusOK) defer resp.Body.Close() @@ -259,7 +259,7 @@ func (s *ControllersSuite) TestRedirectTemplating() { }, } result := campaign.Results[0] - resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) + resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) s.Nil(err) defer resp.Body.Close() s.Equal(http.StatusFound, resp.StatusCode) diff --git a/controllers/route.go b/controllers/route.go index ff31ece7..c27499d8 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -1,72 +1,153 @@ package controllers import ( - "fmt" + "compress/gzip" + "context" "html/template" "net/http" "net/url" + "time" + "github.com/NYTimes/gziphandler" "github.com/gophish/gophish/auth" "github.com/gophish/gophish/config" ctx "github.com/gophish/gophish/context" log "github.com/gophish/gophish/logger" mid "github.com/gophish/gophish/middleware" "github.com/gophish/gophish/models" + "github.com/gophish/gophish/util" + "github.com/gophish/gophish/worker" "github.com/gorilla/csrf" + "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/gorilla/sessions" + "github.com/jordan-wright/unindexed" ) -// CreateAdminRouter creates the routes for handling requests to the web interface. +// AdminServerOption is a functional option that is used to configure the +// admin server +type AdminServerOption func(*AdminServer) + +// AdminServer is an HTTP server that implements the administrative Gophish +// handlers, including the dashboard and REST API. +type AdminServer struct { + server *http.Server + worker worker.Worker + config config.AdminServer +} + +// WithWorker is an option that sets the background worker. +func WithWorker(w worker.Worker) AdminServerOption { + return func(as *AdminServer) { + as.worker = w + } +} + +// NewAdminServer returns a new instance of the AdminServer with the +// provided config and options applied. +func NewAdminServer(config config.AdminServer, options ...AdminServerOption) *AdminServer { + defaultWorker, _ := worker.New() + defaultServer := &http.Server{ + ReadTimeout: 10 * time.Second, + Addr: config.ListenURL, + } + as := &AdminServer{ + worker: defaultWorker, + server: defaultServer, + config: config, + } + for _, opt := range options { + opt(as) + } + as.registerRoutes() + return as +} + +// Start launches the admin server, listening on the configured address. +func (as *AdminServer) Start() error { + if as.worker != nil { + go as.worker.Start() + } + if as.config.UseTLS { + err := util.CheckAndCreateSSL(as.config.CertPath, as.config.KeyPath) + if err != nil { + log.Fatal(err) + return err + } + log.Infof("Starting admin server at https://%s", as.config.ListenURL) + return as.server.ListenAndServeTLS(as.config.CertPath, as.config.KeyPath) + } + // If TLS isn't configured, just listen on HTTP + log.Infof("Starting admin server at http://%s", as.config.ListenURL) + return as.server.ListenAndServe() +} + +// Shutdown attempts to gracefully shutdown the server. +func (as *AdminServer) Shutdown() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return as.server.Shutdown(ctx) +} + +// SetupAdminRoutes creates the routes for handling requests to the web interface. // This function returns an http.Handler to be used in http.ListenAndServe(). -func CreateAdminRouter() http.Handler { +func (as *AdminServer) registerRoutes() { router := mux.NewRouter() // Base Front-end routes - router.HandleFunc("/", Use(Base, mid.RequireLogin)) - router.HandleFunc("/login", Login) - router.HandleFunc("/logout", Use(Logout, mid.RequireLogin)) - router.HandleFunc("/campaigns", Use(Campaigns, mid.RequireLogin)) - router.HandleFunc("/campaigns/{id:[0-9]+}", Use(CampaignID, mid.RequireLogin)) - router.HandleFunc("/templates", Use(Templates, mid.RequireLogin)) - router.HandleFunc("/users", Use(Users, mid.RequireLogin)) - router.HandleFunc("/landing_pages", Use(LandingPages, mid.RequireLogin)) - router.HandleFunc("/sending_profiles", Use(SendingProfiles, mid.RequireLogin)) - router.HandleFunc("/register", Use(Register, mid.RequireLogin)) - router.HandleFunc("/settings", Use(Settings, mid.RequireLogin)) + router.HandleFunc("/", Use(as.Base, mid.RequireLogin)) + router.HandleFunc("/login", as.Login) + router.HandleFunc("/logout", Use(as.Logout, mid.RequireLogin)) + router.HandleFunc("/campaigns", Use(as.Campaigns, mid.RequireLogin)) + router.HandleFunc("/campaigns/{id:[0-9]+}", Use(as.CampaignID, mid.RequireLogin)) + router.HandleFunc("/templates", Use(as.Templates, mid.RequireLogin)) + router.HandleFunc("/users", Use(as.Users, mid.RequireLogin)) + router.HandleFunc("/landing_pages", Use(as.LandingPages, mid.RequireLogin)) + router.HandleFunc("/sending_profiles", Use(as.SendingProfiles, mid.RequireLogin)) + router.HandleFunc("/register", Use(as.Register, mid.RequireLogin)) + router.HandleFunc("/settings", Use(as.Settings, mid.RequireLogin)) // Create the API routes api := router.PathPrefix("/api").Subrouter() api = api.StrictSlash(true) - api.HandleFunc("/reset", Use(API_Reset, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/", Use(API_Campaigns, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/summary", Use(API_Campaigns_Summary, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/{id:[0-9]+}", Use(API_Campaigns_Id, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/{id:[0-9]+}/results", Use(API_Campaigns_Id_Results, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/{id:[0-9]+}/summary", Use(API_Campaign_Id_Summary, mid.RequireAPIKey)) - api.HandleFunc("/campaigns/{id:[0-9]+}/complete", Use(API_Campaigns_Id_Complete, mid.RequireAPIKey)) - api.HandleFunc("/groups/", Use(API_Groups, mid.RequireAPIKey)) - api.HandleFunc("/groups/summary", Use(API_Groups_Summary, mid.RequireAPIKey)) - api.HandleFunc("/groups/{id:[0-9]+}", Use(API_Groups_Id, mid.RequireAPIKey)) - api.HandleFunc("/groups/{id:[0-9]+}/summary", Use(API_Groups_Id_Summary, mid.RequireAPIKey)) - api.HandleFunc("/templates/", Use(API_Templates, mid.RequireAPIKey)) - api.HandleFunc("/templates/{id:[0-9]+}", Use(API_Templates_Id, mid.RequireAPIKey)) - api.HandleFunc("/pages/", Use(API_Pages, mid.RequireAPIKey)) - api.HandleFunc("/pages/{id:[0-9]+}", Use(API_Pages_Id, mid.RequireAPIKey)) - api.HandleFunc("/smtp/", Use(API_SMTP, mid.RequireAPIKey)) - api.HandleFunc("/smtp/{id:[0-9]+}", Use(API_SMTP_Id, mid.RequireAPIKey)) - api.HandleFunc("/util/send_test_email", Use(API_Send_Test_Email, mid.RequireAPIKey)) - api.HandleFunc("/import/group", Use(API_Import_Group, mid.RequireAPIKey)) - api.HandleFunc("/import/email", Use(API_Import_Email, mid.RequireAPIKey)) - api.HandleFunc("/import/site", Use(API_Import_Site, mid.RequireAPIKey)) + api.Use(mid.RequireAPIKey) + api.HandleFunc("/reset", as.API_Reset) + api.HandleFunc("/campaigns/", as.API_Campaigns) + api.HandleFunc("/campaigns/summary", as.API_Campaigns_Summary) + api.HandleFunc("/campaigns/{id:[0-9]+}", as.API_Campaigns_Id) + api.HandleFunc("/campaigns/{id:[0-9]+}/results", as.API_Campaigns_Id_Results) + api.HandleFunc("/campaigns/{id:[0-9]+}/summary", as.API_Campaign_Id_Summary) + api.HandleFunc("/campaigns/{id:[0-9]+}/complete", as.API_Campaigns_Id_Complete) + api.HandleFunc("/groups/", as.API_Groups) + api.HandleFunc("/groups/summary", as.API_Groups_Summary) + api.HandleFunc("/groups/{id:[0-9]+}", as.API_Groups_Id) + api.HandleFunc("/groups/{id:[0-9]+}/summary", as.API_Groups_Id_Summary) + api.HandleFunc("/templates/", as.API_Templates) + api.HandleFunc("/templates/{id:[0-9]+}", as.API_Templates_Id) + api.HandleFunc("/pages/", as.API_Pages) + api.HandleFunc("/pages/{id:[0-9]+}", as.API_Pages_Id) + api.HandleFunc("/smtp/", as.API_SMTP) + api.HandleFunc("/smtp/{id:[0-9]+}", as.API_SMTP_Id) + api.HandleFunc("/util/send_test_email", as.API_Send_Test_Email) + api.HandleFunc("/import/group", as.API_Import_Group) + api.HandleFunc("/import/email", as.API_Import_Email) + api.HandleFunc("/import/site", as.API_Import_Site) // Setup static file serving - router.PathPrefix("/").Handler(http.FileServer(UnindexedFileSystem{http.Dir("./static/")})) + router.PathPrefix("/").Handler(http.FileServer(unindexed.Dir("./static/"))) // Setup CSRF Protection csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()), csrf.FieldName("csrf_token"), - csrf.Secure(config.Conf.AdminConf.UseTLS)) - csrfRouter := csrfHandler(router) - return Use(csrfRouter.ServeHTTP, mid.CSRFExceptions, mid.GetContext) + csrf.Secure(as.config.UseTLS)) + adminHandler := csrfHandler(router) + adminHandler = Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext) + + // Setup GZIP compression + gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression) + adminHandler = gzipWrapper(adminHandler) + + // Setup logging + adminHandler = handlers.CombinedLoggingHandler(log.Writer(), adminHandler) + as.server.Handler = adminHandler } // Use allows us to stack middleware to process the request @@ -78,16 +159,29 @@ func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) h return handler } +type templateParams struct { + Title string + Flashes []interface{} + User models.User + Token string + Version string +} + +// newTemplateParams returns the default template parameters for a user and +// the CSRF token. +func newTemplateParams(r *http.Request) templateParams { + return templateParams{ + Token: csrf.Token(r), + User: ctx.Get(r, "user").(models.User), + Version: config.Version, + } +} + // Register creates a new user -func Register(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) Register(w http.ResponseWriter, r *http.Request) { // If it is a post request, attempt to register the account // Now that we are all registered, we can log the user in - params := struct { - Title string - Flashes []interface{} - User models.User - Token string - }{Title: "Register", Token: csrf.Token(r)} + params := templateParams{Title: "Register", Token: csrf.Token(r)} session := ctx.Get(r, "session").(*sessions.Session) switch { case r.Method == "GET": @@ -120,99 +214,60 @@ func Register(w http.ResponseWriter, r *http.Request) { } // Base handles the default path and template execution -func Base(w http.ResponseWriter, r *http.Request) { - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Dashboard", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) Base(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Dashboard" getTemplate(w, "dashboard").ExecuteTemplate(w, "base", params) } // Campaigns handles the default path and template execution -func Campaigns(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Campaigns", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) Campaigns(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Campaigns" getTemplate(w, "campaigns").ExecuteTemplate(w, "base", params) } // CampaignID handles the default path and template execution -func CampaignID(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Campaign Results", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) CampaignID(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Campaign Results" getTemplate(w, "campaign_results").ExecuteTemplate(w, "base", params) } // Templates handles the default path and template execution -func Templates(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Email Templates", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) Templates(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Email Templates" getTemplate(w, "templates").ExecuteTemplate(w, "base", params) } // Users handles the default path and template execution -func Users(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Users & Groups", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) Users(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Users & Groups" getTemplate(w, "users").ExecuteTemplate(w, "base", params) } // LandingPages handles the default path and template execution -func LandingPages(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Landing Pages", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) LandingPages(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Landing Pages" getTemplate(w, "landing_pages").ExecuteTemplate(w, "base", params) } // SendingProfiles handles the default path and template execution -func SendingProfiles(w http.ResponseWriter, r *http.Request) { - // Example of using session - will be removed. - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - }{Title: "Sending Profiles", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} +func (as *AdminServer) SendingProfiles(w http.ResponseWriter, r *http.Request) { + params := newTemplateParams(r) + params.Title = "Sending Profiles" getTemplate(w, "sending_profiles").ExecuteTemplate(w, "base", params) } // Settings handles the changing of settings -func Settings(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) Settings(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "GET": - params := struct { - User models.User - Title string - Flashes []interface{} - Token string - Version string - }{Title: "Settings", Version: config.Version, User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)} + params := newTemplateParams(r) + params.Title = "Settings" getTemplate(w, "settings").ExecuteTemplate(w, "base", params) case r.Method == "POST": err := auth.ChangePassword(r) @@ -235,7 +290,7 @@ func Settings(w http.ResponseWriter, r *http.Request) { // Login handles the authentication flow for a user. If credentials are valid, // a session is created -func Login(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) Login(w http.ResponseWriter, r *http.Request) { params := struct { User models.User Title string @@ -289,7 +344,7 @@ func Login(w http.ResponseWriter, r *http.Request) { } // Logout destroys the current user session -func Logout(w http.ResponseWriter, r *http.Request) { +func (as *AdminServer) Logout(w http.ResponseWriter, r *http.Request) { session := ctx.Get(r, "session").(*sessions.Session) delete(session.Values, "id") Flash(w, r, "success", "You have successfully logged out") @@ -297,28 +352,6 @@ func Logout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", 302) } -// Preview allows for the viewing of page html in a separate browser window -func Preview(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusBadRequest) - return - } - fmt.Fprintf(w, "%s", r.FormValue("html")) -} - -// Clone takes a URL as a POST parameter and returns the site HTML -func Clone(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusBadRequest) - return - } - if url, ok := vars["url"]; ok { - log.Error(url) - } - http.Error(w, "No URL given.", http.StatusBadRequest) -} - func getTemplate(w http.ResponseWriter, tmpl string) *template.Template { templates := template.New("template") _, err := templates.ParseFiles("templates/base.html", "templates/"+tmpl+".html", "templates/flashes.html") diff --git a/controllers/route_test.go b/controllers/route_test.go index 09d7ec45..c9c7975e 100644 --- a/controllers/route_test.go +++ b/controllers/route_test.go @@ -10,7 +10,7 @@ import ( ) func (s *ControllersSuite) TestLoginCSRF() { - resp, err := http.PostForm(fmt.Sprintf("%s/login", as.URL), + resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL), url.Values{ "username": {"admin"}, "password": {"gophish"}, @@ -21,7 +21,7 @@ func (s *ControllersSuite) TestLoginCSRF() { } func (s *ControllersSuite) TestInvalidCredentials() { - resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) + resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) s.Equal(err, nil) s.Equal(resp.StatusCode, http.StatusOK) @@ -32,7 +32,7 @@ func (s *ControllersSuite) TestInvalidCredentials() { s.Equal(ok, true) client := &http.Client{} - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{ + req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{ "username": {"admin"}, "password": {"invalid"}, "csrf_token": {token}, @@ -48,7 +48,7 @@ func (s *ControllersSuite) TestInvalidCredentials() { } func (s *ControllersSuite) TestSuccessfulLogin() { - resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) + resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) s.Equal(err, nil) s.Equal(resp.StatusCode, http.StatusOK) @@ -59,7 +59,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() { s.Equal(ok, true) client := &http.Client{} - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{ + req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{ "username": {"admin"}, "password": {"gophish"}, "csrf_token": {token}, @@ -76,7 +76,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() { func (s *ControllersSuite) TestSuccessfulRedirect() { next := "/campaigns" - resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) + resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL)) s.Equal(err, nil) s.Equal(resp.StatusCode, http.StatusOK) @@ -91,7 +91,7 @@ func (s *ControllersSuite) TestSuccessfulRedirect() { return http.ErrUseLastResponse }, } - req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", as.URL, next), strings.NewReader(url.Values{ + req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{ "username": {"admin"}, "password": {"gophish"}, "csrf_token": {token}, diff --git a/controllers/static.go b/controllers/static.go deleted file mode 100644 index 4cdcbd6d..00000000 --- a/controllers/static.go +++ /dev/null @@ -1,35 +0,0 @@ -package controllers - -import ( - "net/http" - "strings" -) - -// UnindexedFileSystem is an implementation of a standard http.FileSystem -// without the ability to list files in the directory. -// This implementation is largely inspired by -// https://www.alexedwards.net/blog/disable-http-fileserver-directory-listings -type UnindexedFileSystem struct { - fs http.FileSystem -} - -// Open returns a file from the static directory. If the requested path ends -// with a slash, there is a check for an index.html file. If none exists, then -// an error is returned. -func (ufs UnindexedFileSystem) Open(name string) (http.File, error) { - f, err := ufs.fs.Open(name) - if err != nil { - return nil, err - } - - s, err := f.Stat() - if s.IsDir() { - index := strings.TrimSuffix(name, "/") + "/index.html" - indexFile, err := ufs.fs.Open(index) - if err != nil { - return nil, err - } - return indexFile, nil - } - return f, nil -} diff --git a/controllers/static_test.go b/controllers/static_test.go deleted file mode 100644 index 0f36464f..00000000 --- a/controllers/static_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package controllers - -import ( - "bytes" - "fmt" - "io/ioutil" - "net/http" - "os" - "path/filepath" -) - -var fileContent = []byte("Hello world") - -func mustRemoveAll(dir string) { - err := os.RemoveAll(dir) - if err != nil { - panic(err) - } -} - -func createTestFile(dir, filename string) error { - return ioutil.WriteFile(filepath.Join(dir, filename), fileContent, 0644) -} - -func (s *ControllersSuite) TestGetStaticFile() { - dir, err := ioutil.TempDir("static/endpoint", "test-") - tempFolder := filepath.Base(dir) - - s.Nil(err) - defer mustRemoveAll(dir) - - err = createTestFile(dir, "foo.txt") - s.Nil(nil, err) - - resp, err := http.Get(fmt.Sprintf("%s/static/%s/foo.txt", ps.URL, tempFolder)) - s.Nil(err) - - defer resp.Body.Close() - got, err := ioutil.ReadAll(resp.Body) - s.Nil(err) - - s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got)) -} - -func (s *ControllersSuite) TestStaticFileListing() { - dir, err := ioutil.TempDir("static/endpoint", "test-") - tempFolder := filepath.Base(dir) - - s.Nil(err) - defer mustRemoveAll(dir) - - err = createTestFile(dir, "foo.txt") - s.Nil(nil, err) - - resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder)) - s.Nil(err) - - defer resp.Body.Close() - s.Nil(err) - s.Equal(resp.StatusCode, http.StatusNotFound) -} - -func (s *ControllersSuite) TestStaticIndex() { - dir, err := ioutil.TempDir("static/endpoint", "test-") - tempFolder := filepath.Base(dir) - - s.Nil(err) - defer mustRemoveAll(dir) - - err = createTestFile(dir, "index.html") - s.Nil(nil, err) - - resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder)) - s.Nil(err) - - defer resp.Body.Close() - got, err := ioutil.ReadAll(resp.Body) - s.Nil(err) - - s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got)) -} diff --git a/gophish.go b/gophish.go index 514323ab..6c6fbd9f 100644 --- a/gophish.go +++ b/gophish.go @@ -26,24 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ import ( - "compress/gzip" - "context" "io/ioutil" - "net/http" "os" - "sync" + "os/signal" "gopkg.in/alecthomas/kingpin.v2" - "github.com/NYTimes/gziphandler" "github.com/gophish/gophish/auth" "github.com/gophish/gophish/config" "github.com/gophish/gophish/controllers" log "github.com/gophish/gophish/logger" - "github.com/gophish/gophish/mailer" "github.com/gophish/gophish/models" - "github.com/gophish/gophish/util" - "github.com/gorilla/handlers" ) var ( @@ -65,31 +58,25 @@ func main() { kingpin.Parse() // Load the config - err = config.LoadConfig(*configPath) + conf, err := config.LoadConfig(*configPath) // Just warn if a contact address hasn't been configured if err != nil { log.Fatal(err) } - if config.Conf.ContactAddress == "" { + if conf.ContactAddress == "" { log.Warnf("No contact address has been configured.") log.Warnf("Please consider adding a contact_address entry in your config.json") } config.Version = string(version) - err = log.Setup() + err = log.Setup(conf) if err != nil { log.Fatal(err) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // Provide the option to disable the built-in mailer - if !*disableMailer { - go mailer.Mailer.Start(ctx) - } // Setup the global variables and settings - err = models.Setup() + err = models.Setup(conf) if err != nil { log.Fatal(err) } @@ -99,39 +86,27 @@ func main() { if err != nil { log.Fatal(err) } - wg := &sync.WaitGroup{} - wg.Add(1) - // Start the web servers - go func() { - defer wg.Done() - gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression) - adminHandler := gzipWrapper(controllers.CreateAdminRouter()) - auth.Store.Options.Secure = config.Conf.AdminConf.UseTLS - if config.Conf.AdminConf.UseTLS { // use TLS for Admin web server if available - err := util.CheckAndCreateSSL(config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath) - if err != nil { - log.Fatal(err) - } - log.Infof("Starting admin server at https://%s", config.Conf.AdminConf.ListenURL) - log.Info(http.ListenAndServeTLS(config.Conf.AdminConf.ListenURL, config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath, - handlers.CombinedLoggingHandler(log.Writer(), adminHandler))) - } else { - log.Infof("Starting admin server at http://%s", config.Conf.AdminConf.ListenURL) - log.Info(http.ListenAndServe(config.Conf.AdminConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, adminHandler))) - } - }() - wg.Add(1) - go func() { - defer wg.Done() - phishHandler := gziphandler.GzipHandler(controllers.CreatePhishingRouter()) - if config.Conf.PhishConf.UseTLS { // use TLS for Phish web server if available - log.Infof("Starting phishing server at https://%s", config.Conf.PhishConf.ListenURL) - log.Info(http.ListenAndServeTLS(config.Conf.PhishConf.ListenURL, config.Conf.PhishConf.CertPath, config.Conf.PhishConf.KeyPath, - handlers.CombinedLoggingHandler(log.Writer(), phishHandler))) - } else { - log.Infof("Starting phishing server at http://%s", config.Conf.PhishConf.ListenURL) - log.Fatal(http.ListenAndServe(config.Conf.PhishConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, phishHandler))) - } - }() - wg.Wait() + + // Create our servers + adminOptions := []controllers.AdminServerOption{} + if *disableMailer { + adminOptions = append(adminOptions, controllers.WithWorker(nil)) + } + adminConfig := conf.AdminConf + adminServer := controllers.NewAdminServer(adminConfig, adminOptions...) + auth.Store.Options.Secure = adminConfig.UseTLS + + phishConfig := conf.PhishConf + phishServer := controllers.NewPhishingServer(phishConfig) + + go adminServer.Start() + go phishServer.Start() + + // Handle graceful shutdown + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c + log.Info("CTRL+C Received... Gracefully shutting down servers") + adminServer.Shutdown() + phishServer.Shutdown() } diff --git a/logger/logger.go b/logger/logger.go index feb7a167..1e107258 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -18,10 +18,10 @@ func init() { } // Setup configures the logger based on options in the config.json. -func Setup() error { +func Setup(conf *config.Config) error { Logger.SetLevel(logrus.InfoLevel) // Set up logging to a file if specified in the config - logFile := config.Conf.Logging.Filename + logFile := conf.Logging.Filename if logFile != "" { f, err := os.OpenFile(logFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) if err != nil { diff --git a/mailer/mailer.go b/mailer/mailer.go index 4a2a7410..99a3f38a 100644 --- a/mailer/mailer.go +++ b/mailer/mailer.go @@ -29,6 +29,13 @@ func (e *ErrMaxConnectAttempts) Error() string { return errString } +// Mailer is an interface that defines an object used to queue and +// send mailer.Mail instances. +type Mailer interface { + Start(ctx context.Context) + Queue([]Mail) +} + // Sender exposes the common operations required for sending email. type Sender interface { Send(from string, to []string, msg io.WriterTo) error @@ -50,27 +57,18 @@ type Mail interface { GetDialer() (Dialer, error) } -// Mailer is a global instance of the mailer that can -// be used in applications. It is the responsibility of the application -// to call Mailer.Start() -var Mailer *MailWorker - -func init() { - Mailer = NewMailWorker() -} - // MailWorker is the worker that receives slices of emails // on a channel to send. It's assumed that every slice of emails received is meant // to be sent to the same server. type MailWorker struct { - Queue chan []Mail + queue chan []Mail } // NewMailWorker returns an instance of MailWorker with the mail queue // initialized. func NewMailWorker() *MailWorker { return &MailWorker{ - Queue: make(chan []Mail), + queue: make(chan []Mail), } } @@ -81,7 +79,7 @@ func (mw *MailWorker) Start(ctx context.Context) { select { case <-ctx.Done(): return - case ms := <-mw.Queue: + case ms := <-mw.queue: go func(ctx context.Context, ms []Mail) { dialer, err := ms[0].GetDialer() if err != nil { @@ -94,6 +92,11 @@ func (mw *MailWorker) Start(ctx context.Context) { } } +// Queue sends the provided mail to the internal queue for processing. +func (mw *MailWorker) Queue(ms []Mail) { + mw.queue <- ms +} + // errorMail is a helper to handle erroring out a slice of Mail instances // in the case that an unrecoverable error occurs. func errorMail(err error, ms []Mail) { diff --git a/mailer/mailer_test.go b/mailer/mailer_test.go index a1573c39..a8c0450f 100644 --- a/mailer/mailer_test.go +++ b/mailer/mailer_test.go @@ -88,7 +88,7 @@ func (ms *MailerSuite) TestMailWorkerStart() { messages := generateMessages(dialer) // Send the campaign - mw.Queue <- messages + mw.Queue(messages) got := []*mockMessage{} @@ -129,7 +129,7 @@ func (ms *MailerSuite) TestBackoff() { messages := generateMessages(dialer) // Send the campaign - mw.Queue <- messages + mw.Queue(messages) got := []*mockMessage{} @@ -183,7 +183,7 @@ func (ms *MailerSuite) TestPermError() { messages := generateMessages(dialer) // Send the campaign - mw.Queue <- messages + mw.Queue(messages) got := []*mockMessage{} @@ -242,7 +242,7 @@ func (ms *MailerSuite) TestUnknownError() { messages := generateMessages(dialer) // Send the campaign - mw.Queue <- messages + mw.Queue(messages) got := []*mockMessage{} diff --git a/middleware/middleware.go b/middleware/middleware.go index d0e0187b..25e0c4b8 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -60,8 +60,10 @@ func GetContext(handler http.Handler) http.HandlerFunc { } } -func RequireAPIKey(handler http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +// RequireAPIKey ensures that a valid API key is set as either the api_key GET +// parameter, or a Bearer token. +func RequireAPIKey(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") if r.Method == "OPTIONS" { w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") @@ -81,18 +83,18 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc { } } if ak == "" { - JSONError(w, 400, "API Key not set") + JSONError(w, http.StatusUnauthorized, "API Key not set") return } u, err := models.GetUserByAPIKey(ak) if err != nil { - JSONError(w, 400, "Invalid API Key") + JSONError(w, http.StatusUnauthorized, "Invalid API Key") return } r = ctx.Set(r, "user_id", u.Id) r = ctx.Set(r, "api_key", ak) handler.ServeHTTP(w, r) - } + }) } // RequireLogin is a simple middleware which checks to see if the user is currently logged in. @@ -104,7 +106,7 @@ func RequireLogin(handler http.Handler) http.HandlerFunc { } else { q := r.URL.Query() q.Set("next", r.URL.Path) - http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), 302) + http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect) } } } diff --git a/models/campaign.go b/models/campaign.go index 2841e564..9017f787 100644 --- a/models/campaign.go +++ b/models/campaign.go @@ -165,7 +165,7 @@ func (c *Campaign) AddEvent(e *Event) error { // an error is returned. Otherwise, the attribute name is set to [Deleted], // indicating the user deleted the attribute (template, smtp, etc.) func (c *Campaign) getDetails() error { - err = db.Model(c).Related(&c.Results).Error + err := db.Model(c).Related(&c.Results).Error if err != nil { log.Warnf("%s: results not found for campaign", err) return err @@ -402,7 +402,8 @@ func GetQueuedCampaigns(t time.Time) ([]Campaign, error) { // PostCampaign inserts a campaign and all associated records into the database. func PostCampaign(c *Campaign, uid int64) error { - if err := c.Validate(); err != nil { + err := c.Validate() + if err != nil { return err } // Fill in the details @@ -514,14 +515,16 @@ func PostCampaign(c *Campaign, uid int64) error { Reported: false, ModifiedDate: c.CreatedDate, } - if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) { - r.Status = STATUS_SENDING - } err = r.GenerateId() if err != nil { log.Error(err) continue } + processing := false + if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) { + r.Status = STATUS_SENDING + processing = true + } err = db.Save(r).Error if err != nil { log.WithFields(logrus.Fields{ @@ -530,7 +533,14 @@ func PostCampaign(c *Campaign, uid int64) error { } c.Results = append(c.Results, *r) log.Infof("Creating maillog for %s to send at %s\n", r.Email, sendDate) - err = GenerateMailLog(c, r, sendDate) + m := &MailLog{ + UserId: c.UserId, + CampaignId: c.Id, + RId: r.RId, + SendDate: sendDate, + Processing: processing, + } + err = db.Save(m).Error if err != nil { log.Error(err) continue diff --git a/models/campaign_test.go b/models/campaign_test.go index 7cb3e2d3..b0f36d2b 100644 --- a/models/campaign_test.go +++ b/models/campaign_test.go @@ -79,3 +79,29 @@ func (s *ModelsSuite) TestCampaignDateValidation(c *check.C) { err = campaign.Validate() c.Assert(err, check.Equals, ErrInvalidSendByDate) } + +func (s *ModelsSuite) TestLaunchCampaignMaillogStatus(c *check.C) { + // For the first test, ensure that campaigns created with the zero date + // (and therefore are set to launch immediately) have maillogs that are + // locked to prevent race conditions. + campaign := s.createCampaign(c) + ms, err := GetMailLogsByCampaign(campaign.Id) + c.Assert(err, check.Equals, nil) + + for _, m := range ms { + c.Assert(m.Processing, check.Equals, true) + } + + // Next, verify that campaigns scheduled in the future do not lock the + // maillogs so that they can be picked up by the background worker. + campaign = s.createCampaignDependencies(c) + campaign.Name = "New Campaign" + campaign.LaunchDate = time.Now().Add(1 * time.Hour) + c.Assert(PostCampaign(&campaign, campaign.UserId), check.Equals, nil) + ms, err = GetMailLogsByCampaign(campaign.Id) + c.Assert(err, check.Equals, nil) + + for _, m := range ms { + c.Assert(m.Processing, check.Equals, false) + } +} diff --git a/models/email_request.go b/models/email_request.go index 7fb60fc7..e4744c93 100644 --- a/models/email_request.go +++ b/models/email_request.go @@ -122,8 +122,8 @@ func (s *EmailRequest) Generate(msg *gomail.Message) error { // Add the transparency headers msg.SetHeader("X-Mailer", config.ServerName) - if config.Conf.ContactAddress != "" { - msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress) + if conf.ContactAddress != "" { + msg.SetHeader("X-Gophish-Contact", conf.ContactAddress) } // Parse the customHeader templates diff --git a/models/email_request_test.go b/models/email_request_test.go index ef3a3dae..a954d988 100644 --- a/models/email_request_test.go +++ b/models/email_request_test.go @@ -26,7 +26,7 @@ func (s *ModelsSuite) TestEmailRequestBackoff(ch *check.C) { } expected := errors.New("Temporary Error") go func() { - err = req.Backoff(expected) + err := req.Backoff(expected) ch.Assert(err, check.Equals, nil) }() ch.Assert(<-req.ErrorChan, check.Equals, expected) @@ -38,7 +38,7 @@ func (s *ModelsSuite) TestEmailRequestError(ch *check.C) { } expected := errors.New("Temporary Error") go func() { - err = req.Error(expected) + err := req.Error(expected) ch.Assert(err, check.Equals, nil) }() ch.Assert(<-req.ErrorChan, check.Equals, expected) @@ -49,7 +49,7 @@ func (s *ModelsSuite) TestEmailRequestSuccess(ch *check.C) { ErrorChan: make(chan error), } go func() { - err = req.Success() + err := req.Success() ch.Assert(err, check.Equals, nil) }() ch.Assert(<-req.ErrorChan, check.Equals, nil) @@ -76,14 +76,14 @@ func (s *ModelsSuite) TestEmailRequestGenerate(ch *check.C) { FromAddress: smtp.FromAddress, } - config.Conf.ContactAddress = "test@test.com" + s.config.ContactAddress = "test@test.com" expectedHeaders := map[string]string{ "X-Mailer": config.ServerName, - "X-Gophish-Contact": config.Conf.ContactAddress, + "X-Gophish-Contact": s.config.ContactAddress, } msg := gomail.NewMessage() - err = req.Generate(msg) + err := req.Generate(msg) ch.Assert(err, check.Equals, nil) expected := &email.Email{ @@ -130,7 +130,7 @@ func (s *ModelsSuite) TestEmailRequestURLTemplating(ch *check.C) { } msg := gomail.NewMessage() - err = req.Generate(msg) + err := req.Generate(msg) ch.Assert(err, check.Equals, nil) expectedURL := fmt.Sprintf("http://127.0.0.1/%s?%s=%s", req.Email, RecipientParameter, req.RId) @@ -167,7 +167,7 @@ func (s *ModelsSuite) TestEmailRequestGenerateEmptySubject(ch *check.C) { } msg := gomail.NewMessage() - err = req.Generate(msg) + err := req.Generate(msg) ch.Assert(err, check.Equals, nil) expected := &email.Email{ diff --git a/models/group.go b/models/group.go index a7f0248a..198a0f6e 100644 --- a/models/group.go +++ b/models/group.go @@ -196,7 +196,7 @@ func PostGroup(g *Group) error { return err } // Insert the group into the DB - err = db.Save(g).Error + err := db.Save(g).Error if err != nil { log.Error(err) return err @@ -214,7 +214,7 @@ func PutGroup(g *Group) error { } // Fetch group's existing targets from database. ts := []Target{} - ts, err = GetTargets(g.Id) + ts, err := GetTargets(g.Id) if err != nil { log.WithFields(logrus.Fields{ "group_id": g.Id, @@ -234,7 +234,7 @@ func PutGroup(g *Group) error { } // If the target does not exist in the group any longer, we delete it if !tExists { - err = db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error + err := db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error if err != nil { log.WithFields(logrus.Fields{ "email": t.Email, @@ -286,14 +286,14 @@ func DeleteGroup(g *Group) error { } func insertTargetIntoGroup(t Target, gid int64) error { - if _, err = mail.ParseAddress(t.Email); err != nil { + if _, err := mail.ParseAddress(t.Email); err != nil { log.WithFields(logrus.Fields{ "email": t.Email, }).Error("Invalid email") return err } trans := db.Begin() - err = trans.Where(t).FirstOrCreate(&t).Error + err := trans.Where(t).FirstOrCreate(&t).Error if err != nil { log.WithFields(logrus.Fields{ "email": t.Email, diff --git a/models/maillog.go b/models/maillog.go index 3b56a163..bbcc7a25 100644 --- a/models/maillog.go +++ b/models/maillog.go @@ -45,8 +45,7 @@ func GenerateMailLog(c *Campaign, r *Result, sendDate time.Time) error { RId: r.RId, SendDate: sendDate, } - err = db.Save(m).Error - return err + return db.Save(m).Error } // Backoff sets the MailLog SendDate to be the next entry in an exponential @@ -160,8 +159,8 @@ func (m *MailLog) Generate(msg *gomail.Message) error { // Add the transparency headers msg.SetHeader("X-Mailer", config.ServerName) - if config.Conf.ContactAddress != "" { - msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress) + if conf.ContactAddress != "" { + msg.SetHeader("X-Gophish-Contact", conf.ContactAddress) } // Parse the customHeader templates for _, header := range c.SMTP.Headers { @@ -260,6 +259,5 @@ func LockMailLogs(ms []*MailLog, lock bool) error { // in the database. This is intended to be called when Gophish is started // so that any previously locked maillogs can resume processing. func UnlockAllMailLogs() error { - err = db.Model(&MailLog{}).Update("processing", false).Error - return err + return db.Model(&MailLog{}).Update("processing", false).Error } diff --git a/models/maillog_test.go b/models/maillog_test.go index c3c73a9c..5011030f 100644 --- a/models/maillog_test.go +++ b/models/maillog_test.go @@ -37,7 +37,13 @@ func (s *ModelsSuite) emailFromFirstMailLog(campaign Campaign, ch *check.C) *ema func (s *ModelsSuite) TestGetQueuedMailLogs(ch *check.C) { campaign := s.createCampaign(ch) - ms, err := GetQueuedMailLogs(campaign.LaunchDate) + // By default, for campaigns with no launch date, the maillogs are set as + // being processed. We need to unlock them first. + ms, err := GetMailLogsByCampaign(campaign.Id) + ch.Assert(err, check.Equals, nil) + err = LockMailLogs(ms, false) + ch.Assert(err, check.Equals, nil) + ms, err = GetQueuedMailLogs(campaign.LaunchDate) ch.Assert(err, check.Equals, nil) got := make(map[string]*MailLog) for _, m := range ms { @@ -222,10 +228,10 @@ func (s *ModelsSuite) TestMailLogGenerate(ch *check.C) { } func (s *ModelsSuite) TestMailLogGenerateTransparencyHeaders(ch *check.C) { - config.Conf.ContactAddress = "test@test.com" + s.config.ContactAddress = "test@test.com" expectedHeaders := map[string]string{ "X-Mailer": config.ServerName, - "X-Gophish-Contact": config.Conf.ContactAddress, + "X-Gophish-Contact": s.config.ContactAddress, } campaign := s.createCampaign(ch) got := s.emailFromFirstMailLog(campaign, ch) @@ -264,12 +270,6 @@ func (s *ModelsSuite) TestUnlockAllMailLogs(ch *check.C) { campaign := s.createCampaign(ch) ms, err := GetMailLogsByCampaign(campaign.Id) ch.Assert(err, check.Equals, nil) - for _, m := range ms { - ch.Assert(m.Processing, check.Equals, false) - } - err = LockMailLogs(ms, true) - ms, err = GetMailLogsByCampaign(campaign.Id) - ch.Assert(err, check.Equals, nil) for _, m := range ms { ch.Assert(m.Processing, check.Equals, true) } diff --git a/models/models.go b/models/models.go index eb7bfe2a..52f930a5 100644 --- a/models/models.go +++ b/models/models.go @@ -15,7 +15,7 @@ import ( ) var db *gorm.DB -var err error +var conf *config.Config const ( CAMPAIGN_IN_PROGRESS string = "In progress" @@ -78,12 +78,14 @@ func chooseDBDriver(name, openStr string) goose.DBDriver { // Setup initializes the Conn object // It also populates the Gophish Config object -func Setup() error { +func Setup(c *config.Config) error { + // Setup the package-scoped config + conf = c // Setup the goose configuration migrateConf := &goose.DBConf{ - MigrationsDir: config.Conf.MigrationsPath, + MigrationsDir: conf.MigrationsPath, Env: "production", - Driver: chooseDBDriver(config.Conf.DBName, config.Conf.DBPath), + Driver: chooseDBDriver(conf.DBName, conf.DBPath), } // Get the latest possible migration latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir) @@ -92,7 +94,7 @@ func Setup() error { return err } // Open our database connection - db, err = gorm.Open(config.Conf.DBName, config.Conf.DBPath) + db, err = gorm.Open(conf.DBName, conf.DBPath) db.LogMode(false) db.SetLogger(log.Logger) db.DB().SetMaxOpenConns(1) diff --git a/models/models_test.go b/models/models_test.go index 8374911c..f11e8bf2 100644 --- a/models/models_test.go +++ b/models/models_test.go @@ -10,15 +10,20 @@ import ( // Hook up gocheck into the "go test" runner. func Test(t *testing.T) { check.TestingT(t) } -type ModelsSuite struct{} +type ModelsSuite struct { + config *config.Config +} var _ = check.Suite(&ModelsSuite{}) func (s *ModelsSuite) SetUpSuite(c *check.C) { - config.Conf.DBName = "sqlite3" - config.Conf.DBPath = ":memory:" - config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" - err := Setup() + conf := &config.Config{ + DBName: "sqlite3", + DBPath: ":memory:", + MigrationsPath: "../db/db_sqlite3/migrations/", + } + s.config = conf + err := Setup(conf) if err != nil { c.Fatalf("Failed creating database: %v", err) } diff --git a/models/page.go b/models/page.go index 73aca7d5..30da2179 100644 --- a/models/page.go +++ b/models/page.go @@ -148,7 +148,7 @@ func PutPage(p *Page) error { // DeletePage deletes an existing page in the database. // An error is returned if a page with the given user id and page id is not found. func DeletePage(id int64, uid int64) error { - err = db.Where("user_id=?", uid).Delete(Page{Id: id}).Error + err := db.Where("user_id=?", uid).Delete(Page{Id: id}).Error if err != nil { log.Error(err) } diff --git a/models/smtp.go b/models/smtp.go index ea7d4c23..f6dca63e 100644 --- a/models/smtp.go +++ b/models/smtp.go @@ -228,7 +228,7 @@ func PutSMTP(s *SMTP) error { // An error is returned if a SMTP with the given user id and SMTP id is not found. func DeleteSMTP(id int64, uid int64) error { // Delete all custom headers - err = db.Where("smtp_id=?", id).Delete(&Header{}).Error + err := db.Where("smtp_id=?", id).Delete(&Header{}).Error if err != nil { log.Error(err) return err diff --git a/models/smtp_test.go b/models/smtp_test.go index ecf17f58..e6552a9e 100644 --- a/models/smtp_test.go +++ b/models/smtp_test.go @@ -15,7 +15,7 @@ func (s *ModelsSuite) TestPostSMTP(c *check.C) { FromAddress: "Foo Bar ", UserId: 1, } - err = PostSMTP(&smtp) + err := PostSMTP(&smtp) c.Assert(err, check.Equals, nil) ss, err := GetSMTPs(1) c.Assert(err, check.Equals, nil) @@ -28,7 +28,7 @@ func (s *ModelsSuite) TestPostSMTPNoHost(c *check.C) { FromAddress: "Foo Bar ", UserId: 1, } - err = PostSMTP(&smtp) + err := PostSMTP(&smtp) c.Assert(err, check.Equals, ErrHostNotSpecified) } @@ -38,7 +38,7 @@ func (s *ModelsSuite) TestPostSMTPNoFrom(c *check.C) { UserId: 1, Host: "1.1.1.1:25", } - err = PostSMTP(&smtp) + err := PostSMTP(&smtp) c.Assert(err, check.Equals, ErrFromAddressNotSpecified) } @@ -53,7 +53,7 @@ func (s *ModelsSuite) TestPostSMTPValidHeader(c *check.C) { Header{Key: "X-Mailer", Value: "gophish"}, }, } - err = PostSMTP(&smtp) + err := PostSMTP(&smtp) c.Assert(err, check.Equals, nil) ss, err := GetSMTPs(1) c.Assert(err, check.Equals, nil) diff --git a/models/template.go b/models/template.go index a4e071ca..8f0637e9 100644 --- a/models/template.go +++ b/models/template.go @@ -34,10 +34,10 @@ func (t *Template) Validate() error { case t.Text == "" && t.HTML == "": return ErrTemplateMissingParameter } - if err = ValidateTemplate(t.HTML); err != nil { + if err := ValidateTemplate(t.HTML); err != nil { return err } - if err = ValidateTemplate(t.Text); err != nil { + if err := ValidateTemplate(t.Text); err != nil { return err } return nil @@ -113,7 +113,7 @@ func PostTemplate(t *Template) error { if err := t.Validate(); err != nil { return err } - err = db.Save(t).Error + err := db.Save(t).Error if err != nil { log.Error(err) return err @@ -138,7 +138,7 @@ func PutTemplate(t *Template) error { return err } // Delete all attachments, and replace with new ones - err = db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error + err := db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error if err != nil && err != gorm.ErrRecordNotFound { log.Error(err) return err @@ -146,7 +146,7 @@ func PutTemplate(t *Template) error { if err == gorm.ErrRecordNotFound { err = nil } - for i, _ := range t.Attachments { + for i := range t.Attachments { t.Attachments[i].TemplateId = t.Id err := db.Save(&t.Attachments[i]).Error if err != nil { diff --git a/util/util_test.go b/util/util_test.go index 1274071c..874e3cbf 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -18,10 +18,12 @@ type UtilSuite struct { } func (s *UtilSuite) SetupSuite() { - config.Conf.DBName = "sqlite3" - config.Conf.DBPath = ":memory:" - config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" - err := models.Setup() + conf := &config.Config{ + DBName: "sqlite3", + DBPath: ":memory:", + MigrationsPath: "../db/db_sqlite3/migrations/", + } + err := models.Setup(conf) if err != nil { s.T().Fatalf("Failed creating database: %v", err) } diff --git a/worker/worker.go b/worker/worker.go index eaa6c6cc..7161ec4b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -1,6 +1,7 @@ package worker import ( + "context" "time" log "github.com/gophish/gophish/logger" @@ -9,18 +10,46 @@ import ( "github.com/sirupsen/logrus" ) -// Worker is the background worker that handles watching for new campaigns and sending emails appropriately. -type Worker struct{} +// Worker is an interface that defines the operations needed for a background worker +type Worker interface { + Start() + LaunchCampaign(c models.Campaign) + SendTestEmail(s *models.EmailRequest) error +} + +// DefaultWorker is the background worker that handles watching for new campaigns and sending emails appropriately. +type DefaultWorker struct { + mailer mailer.Mailer +} // New creates a new worker object to handle the creation of campaigns -func New() *Worker { - return &Worker{} +func New(options ...func(Worker) error) (Worker, error) { + defaultMailer := mailer.NewMailWorker() + w := &DefaultWorker{ + mailer: defaultMailer, + } + for _, opt := range options { + if err := opt(w); err != nil { + return nil, err + } + } + return w, nil +} + +// WithMailer sets the mailer for a given worker. +// By default, workers use a standard, default mailworker. +func WithMailer(m mailer.Mailer) func(*DefaultWorker) error { + return func(w *DefaultWorker) error { + w.mailer = m + return nil + } } // Start launches the worker to poll the database every minute for any pending maillogs // that need to be processed. -func (w *Worker) Start() { +func (w *DefaultWorker) Start() { log.Info("Background Worker Started Successfully - Waiting for Campaigns") + go w.mailer.Start(context.Background()) for t := range time.Tick(1 * time.Minute) { ms, err := models.GetQueuedMailLogs(t.UTC()) if err != nil { @@ -62,14 +91,14 @@ func (w *Worker) Start() { log.WithFields(logrus.Fields{ "num_emails": len(msc), }).Info("Sending emails to mailer for processing") - mailer.Mailer.Queue <- msc + w.mailer.Queue(msc) }(cid, msc) } } } // LaunchCampaign starts a campaign -func (w *Worker) LaunchCampaign(c models.Campaign) { +func (w *DefaultWorker) LaunchCampaign(c models.Campaign) { ms, err := models.GetMailLogsByCampaign(c.Id) if err != nil { log.Error(err) @@ -89,14 +118,14 @@ func (w *Worker) LaunchCampaign(c models.Campaign) { } mailEntries = append(mailEntries, m) } - mailer.Mailer.Queue <- mailEntries + w.mailer.Queue(mailEntries) } // SendTestEmail sends a test email -func (w *Worker) SendTestEmail(s *models.EmailRequest) error { +func (w *DefaultWorker) SendTestEmail(s *models.EmailRequest) error { go func() { ms := []mailer.Mail{s} - mailer.Mailer.Queue <- ms + w.mailer.Queue(ms) }() return <-s.ErrorChan } diff --git a/worker/worker_test.go b/worker/worker_test.go index 9ad6383d..a0252604 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -9,17 +9,20 @@ import ( // WorkerSuite is a suite of tests to cover API related functions type WorkerSuite struct { suite.Suite - ApiKey string + config *config.Config } func (s *WorkerSuite) SetupSuite() { - config.Conf.DBName = "sqlite3" - config.Conf.DBPath = ":memory:" - config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" - err := models.Setup() + conf := &config.Config{ + DBName: "sqlite3", + DBPath: ":memory:", + MigrationsPath: "../db/db_sqlite3/migrations/", + } + err := models.Setup(conf) if err != nil { s.T().Fatalf("Failed creating database: %v", err) } + s.config = conf s.Nil(err) } @@ -31,7 +34,7 @@ func (s *WorkerSuite) TearDownTest() { } func (s *WorkerSuite) SetupTest() { - config.Conf.TestFlag = true + s.config.TestFlag = true // Add a group group := models.Group{Name: "Test Group"} group.Targets = []models.Target{ @@ -73,7 +76,3 @@ func (s *WorkerSuite) SetupTest() { models.PostCampaign(&c, c.UserId) c.UpdateStatus(models.CAMPAIGN_EMAILS_SENT) } - -func (s *WorkerSuite) TestMailSendSuccess() { - // TODO -}