diff --git a/config/config.go b/config/config.go
index c1589837..ba81e2f1 100644
--- a/config/config.go
+++ b/config/config.go
@@ -38,9 +38,6 @@ type Config struct {
Logging LoggingConfig `json:"logging"`
}
-// Conf contains the initialized configuration struct
-var Conf Config
-
// Version contains the current gophish version
var Version = ""
@@ -48,19 +45,20 @@ var Version = ""
const ServerName = "gophish"
// LoadConfig loads the configuration from the specified filepath
-func LoadConfig(filepath string) error {
+func LoadConfig(filepath string) (*Config, error) {
// Get the config file
configFile, err := ioutil.ReadFile(filepath)
if err != nil {
- return err
+ return nil, err
}
- err = json.Unmarshal(configFile, &Conf)
+ config := &Config{}
+ err = json.Unmarshal(configFile, config)
if err != nil {
- return err
+ return nil, err
}
// Choosing the migrations directory based on the database used.
- Conf.MigrationsPath = Conf.MigrationsPath + Conf.DBName
+ config.MigrationsPath = config.MigrationsPath + config.DBName
// Explicitly set the TestFlag to false to prevent config.json overrides
- Conf.TestFlag = false
- return nil
+ config.TestFlag = false
+ return config, nil
}
diff --git a/config/config_test.go b/config/config_test.go
index 7e16b81a..26f4fd4f 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -48,18 +48,18 @@ func (s *ConfigSuite) TestLoadConfig() {
_, err := s.ConfigFile.Write(validConfig)
s.Nil(err)
// Load the valid config
- err = LoadConfig(s.ConfigFile.Name())
+ conf, err := LoadConfig(s.ConfigFile.Name())
s.Nil(err)
- expectedConfig := Config{}
+ expectedConfig := &Config{}
err = json.Unmarshal(validConfig, &expectedConfig)
s.Nil(err)
expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName
expectedConfig.TestFlag = false
- s.Equal(expectedConfig, Conf)
+ s.Equal(expectedConfig, conf)
// Load an invalid config
- err = LoadConfig("bogusfile")
+ conf, err = LoadConfig("bogusfile")
s.NotNil(err)
}
diff --git a/controllers/api.go b/controllers/api.go
index 52abfce0..bd6d3938 100644
--- a/controllers/api.go
+++ b/controllers/api.go
@@ -17,23 +17,14 @@ import (
log "github.com/gophish/gophish/logger"
"github.com/gophish/gophish/models"
"github.com/gophish/gophish/util"
- "github.com/gophish/gophish/worker"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
"github.com/jordan-wright/email"
"github.com/sirupsen/logrus"
)
-// Worker is the worker that processes phishing events and updates campaigns.
-var Worker *worker.Worker
-
-func init() {
- Worker = worker.New()
- go Worker.Start()
-}
-
// API (/api/reset) resets a user's API key
-func API_Reset(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Reset(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "POST":
u := ctx.Get(r, "user").(models.User)
@@ -49,7 +40,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) {
// API_Campaigns returns a list of campaigns if requested via GET.
// If requested via POST, API_Campaigns creates a new campaign and returns a reference to it.
-func API_Campaigns(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaigns(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
cs, err := models.GetCampaigns(ctx.Get(r, "user_id").(int64))
@@ -74,14 +65,14 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) {
// If the campaign is scheduled to launch immediately, send it to the worker.
// Otherwise, the worker will pick it up at the scheduled time
if c.Status == models.CAMPAIGN_IN_PROGRESS {
- go Worker.LaunchCampaign(c)
+ go as.worker.LaunchCampaign(c)
}
JSONResponse(w, c, http.StatusCreated)
}
}
// API_Campaigns_Summary returns the summary for the current user's campaigns
-func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
cs, err := models.GetCampaignSummaries(ctx.Get(r, "user_id").(int64))
@@ -96,7 +87,7 @@ func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id returns details about the requested campaign. If the campaign is not
// valid, API_Campaigns_Id returns null.
-func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
c, err := models.GetCampaign(id, ctx.Get(r, "user_id").(int64))
@@ -120,7 +111,7 @@ func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id_Results returns just the results for a given campaign to
// significantly reduce the information returned.
-func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
cr, err := models.GetCampaignResults(id, ctx.Get(r, "user_id").(int64))
@@ -136,7 +127,7 @@ func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
}
// API_Campaigns_Id_Summary returns just the summary for a given campaign.
-func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
switch {
@@ -157,7 +148,7 @@ func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id_Complete effectively "ends" a campaign.
// Future phishing emails clicked will return a simple "404" page.
-func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
switch {
@@ -173,7 +164,7 @@ func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
// API_Groups returns a list of groups if requested via GET.
// If requested via POST, API_Groups creates a new group and returns a reference to it.
-func API_Groups(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Groups(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
gs, err := models.GetGroups(ctx.Get(r, "user_id").(int64))
@@ -208,7 +199,7 @@ func API_Groups(w http.ResponseWriter, r *http.Request) {
}
// API_Groups_Summary returns a summary of the groups owned by the current user.
-func API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
gs, err := models.GetGroupSummaries(ctx.Get(r, "user_id").(int64))
@@ -223,7 +214,7 @@ func API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
// API_Groups_Id returns details about the requested group.
// If the group is not valid, API_Groups_Id returns null.
-func API_Groups_Id(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Groups_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64))
@@ -261,7 +252,7 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) {
}
// API_Groups_Id_Summary returns a summary of the groups owned by the current user.
-func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
vars := mux.Vars(r)
@@ -276,7 +267,7 @@ func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
}
// API_Templates handles the functionality for the /api/templates endpoint
-func API_Templates(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Templates(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
ts, err := models.GetTemplates(ctx.Get(r, "user_id").(int64))
@@ -319,7 +310,7 @@ func API_Templates(w http.ResponseWriter, r *http.Request) {
}
// API_Templates_Id handles the functions for the /api/templates/:id endpoint
-func API_Templates_Id(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Templates_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
t, err := models.GetTemplate(id, ctx.Get(r, "user_id").(int64))
@@ -359,7 +350,7 @@ func API_Templates_Id(w http.ResponseWriter, r *http.Request) {
}
// API_Pages handles requests for the /api/pages/ endpoint
-func API_Pages(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Pages(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
ps, err := models.GetPages(ctx.Get(r, "user_id").(int64))
@@ -396,7 +387,7 @@ func API_Pages(w http.ResponseWriter, r *http.Request) {
// API_Pages_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
// of a Page object
-func API_Pages_Id(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Pages_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
p, err := models.GetPage(id, ctx.Get(r, "user_id").(int64))
@@ -436,7 +427,7 @@ func API_Pages_Id(w http.ResponseWriter, r *http.Request) {
}
// API_SMTP handles requests for the /api/smtp/ endpoint
-func API_SMTP(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_SMTP(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
ss, err := models.GetSMTPs(ctx.Get(r, "user_id").(int64))
@@ -473,7 +464,7 @@ func API_SMTP(w http.ResponseWriter, r *http.Request) {
// API_SMTP_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
// of a SMTP object
-func API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64)
s, err := models.GetSMTP(id, ctx.Get(r, "user_id").(int64))
@@ -518,7 +509,7 @@ func API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
}
// API_Import_Group imports a CSV of group members
-func API_Import_Group(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Import_Group(w http.ResponseWriter, r *http.Request) {
ts, err := util.ParseCSV(r)
if err != nil {
JSONResponse(w, models.Response{Success: false, Message: "Error parsing CSV"}, http.StatusInternalServerError)
@@ -530,7 +521,7 @@ func API_Import_Group(w http.ResponseWriter, r *http.Request) {
// API_Import_Email allows for the importing of email.
// Returns a Message object
-func API_Import_Email(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Import_Email(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
return
@@ -579,7 +570,7 @@ func API_Import_Email(w http.ResponseWriter, r *http.Request) {
// API_Import_Site allows for the importing of HTML from a website
// Without "include_resources" set, it will merely place a "base" tag
// so that all resources can be loaded relative to the given URL.
-func API_Import_Site(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Import_Site(w http.ResponseWriter, r *http.Request) {
cr := cloneRequest{}
if r.Method != "POST" {
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
@@ -637,7 +628,7 @@ func API_Import_Site(w http.ResponseWriter, r *http.Request) {
// API_Send_Test_Email sends a test email using the template name
// and Target given.
-func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
s := &models.EmailRequest{
ErrorChan: make(chan error),
UserId: ctx.Get(r, "user_id").(int64),
@@ -735,7 +726,7 @@ func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
}
}
// Send the test email
- err = Worker.SendTestEmail(s)
+ err = as.worker.SendTestEmail(s)
if err != nil {
log.Error(err)
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
diff --git a/controllers/api_test.go b/controllers/api_test.go
index 8e024ece..cc0684c7 100644
--- a/controllers/api_test.go
+++ b/controllers/api_test.go
@@ -11,41 +11,42 @@ import (
"github.com/gophish/gophish/config"
"github.com/gophish/gophish/models"
- "github.com/gorilla/handlers"
"github.com/stretchr/testify/suite"
)
// ControllersSuite is a suite of tests to cover API related functions
type ControllersSuite struct {
suite.Suite
- ApiKey string
+ ApiKey string
+ config *config.Config
+ adminServer *httptest.Server
+ phishServer *httptest.Server
}
-// as is the Admin Server for our API calls
-var as *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreateAdminRouter()))
-
-// ps is the Phishing Server
-var ps *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreatePhishingRouter()))
-
func (s *ControllersSuite) SetupSuite() {
- config.Conf.DBName = "sqlite3"
- config.Conf.DBPath = ":memory:"
- config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
- err := models.Setup()
+ conf := &config.Config{
+ DBName: "sqlite3",
+ DBPath: ":memory:",
+ MigrationsPath: "../db/db_sqlite3/migrations/",
+ }
+ err := models.Setup(conf)
if err != nil {
s.T().Fatalf("Failed creating database: %v", err)
}
+ s.config = conf
s.Nil(err)
// Setup the admin server for use in testing
- as.Config.Addr = config.Conf.AdminConf.ListenURL
- as.Start()
+ s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler)
+ s.adminServer.Config.Addr = s.config.AdminConf.ListenURL
+ s.adminServer.Start()
// Get the API key to use for these tests
u, err := models.GetUser(1)
s.Nil(err)
s.ApiKey = u.ApiKey
// Start the phishing server
- ps.Config.Addr = config.Conf.PhishConf.ListenURL
- ps.Start()
+ s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler)
+ s.phishServer.Config.Addr = s.config.PhishConf.ListenURL
+ s.phishServer.Start()
// Move our cwd up to the project root for help with resolving
// static assets
err = os.Chdir("../")
@@ -103,21 +104,21 @@ func (s *ControllersSuite) SetupTest() {
}
func (s *ControllersSuite) TestRequireAPIKey() {
- resp, err := http.Post(fmt.Sprintf("%s/api/import/site", as.URL), "application/json", nil)
+ resp, err := http.Post(fmt.Sprintf("%s/api/import/site", s.adminServer.URL), "application/json", nil)
s.Nil(err)
defer resp.Body.Close()
- s.Equal(resp.StatusCode, http.StatusBadRequest)
+ s.Equal(resp.StatusCode, http.StatusUnauthorized)
}
func (s *ControllersSuite) TestInvalidAPIKey() {
- resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", as.URL, "bogus-api-key"))
+ resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", s.adminServer.URL, "bogus-api-key"))
s.Nil(err)
defer resp.Body.Close()
- s.Equal(resp.StatusCode, http.StatusBadRequest)
+ s.Equal(resp.StatusCode, http.StatusUnauthorized)
}
func (s *ControllersSuite) TestBearerToken() {
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", as.URL), nil)
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", s.adminServer.URL), nil)
s.Nil(err)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey))
resp, err := http.DefaultClient.Do(req)
@@ -133,7 +134,7 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
}))
hr := fmt.Sprintf("
\n", ts.URL)
defer ts.Close()
- resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", as.URL, s.ApiKey), "application/json",
+ resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", s.adminServer.URL, s.ApiKey), "application/json",
bytes.NewBuffer([]byte(fmt.Sprintf(`
{
"url" : "%s",
@@ -150,8 +151,8 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
func (s *ControllersSuite) TearDownSuite() {
// Tear down the admin and phishing servers
- as.Close()
- ps.Close()
+ s.adminServer.Close()
+ s.phishServer.Close()
}
func TestControllerSuite(t *testing.T) {
diff --git a/controllers/phish.go b/controllers/phish.go
index 04165829..e1996d3d 100644
--- a/controllers/phish.go
+++ b/controllers/phish.go
@@ -1,6 +1,8 @@
package controllers
import (
+ "compress/gzip"
+ "context"
"errors"
"fmt"
"net"
@@ -8,11 +10,15 @@ import (
"strings"
"time"
+ "github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context"
log "github.com/gophish/gophish/logger"
"github.com/gophish/gophish/models"
+ "github.com/gophish/gophish/util"
+ "github.com/gorilla/handlers"
"github.com/gorilla/mux"
+ "github.com/jordan-wright/unindexed"
)
// ErrInvalidRequest is thrown when a request with an invalid structure is
@@ -35,22 +41,91 @@ type TransparencyResponse struct {
// to return a transparency response.
const TransparencySuffix = "+"
-// CreatePhishingRouter creates the router that handles phishing connections.
-func CreatePhishingRouter() http.Handler {
- router := mux.NewRouter()
- fileServer := http.FileServer(UnindexedFileSystem{http.Dir("./static/endpoint/")})
- router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
- router.HandleFunc("/track", PhishTracker)
- router.HandleFunc("/robots.txt", RobotsHandler)
- router.HandleFunc("/{path:.*}/track", PhishTracker)
- router.HandleFunc("/{path:.*}/report", PhishReporter)
- router.HandleFunc("/report", PhishReporter)
- router.HandleFunc("/{path:.*}", PhishHandler)
- return router
+// PhishingServerOption is a functional option that is used to configure the
+// the phishing server
+type PhishingServerOption func(*PhishingServer)
+
+// PhishingServer is an HTTP server that implements the campaign event
+// handlers, such as email open tracking, click tracking, and more.
+type PhishingServer struct {
+ server *http.Server
+ config config.PhishServer
+ contactAddress string
}
-// PhishTracker tracks emails as they are opened, updating the status for the given Result
-func PhishTracker(w http.ResponseWriter, r *http.Request) {
+// NewPhishingServer returns a new instance of the phishing server with
+// provided options applied.
+func NewPhishingServer(config config.PhishServer, options ...PhishingServerOption) *PhishingServer {
+ defaultServer := &http.Server{
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ Addr: config.ListenURL,
+ }
+ ps := &PhishingServer{
+ server: defaultServer,
+ config: config,
+ }
+ for _, opt := range options {
+ opt(ps)
+ }
+ ps.registerRoutes()
+ return ps
+}
+
+// WithContactAddress sets the contact address used by the transparency
+// handlers
+func WithContactAddress(addr string) PhishingServerOption {
+ return func(ps *PhishingServer) {
+ ps.contactAddress = addr
+ }
+}
+
+// Start launches the phishing server, listening on the configured address.
+func (ps *PhishingServer) Start() error {
+ if ps.config.UseTLS {
+ err := util.CheckAndCreateSSL(ps.config.CertPath, ps.config.KeyPath)
+ if err != nil {
+ log.Fatal(err)
+ return err
+ }
+ log.Infof("Starting phishing server at https://%s", ps.config.ListenURL)
+ return ps.server.ListenAndServeTLS(ps.config.CertPath, ps.config.KeyPath)
+ }
+ // If TLS isn't configured, just listen on HTTP
+ log.Infof("Starting phishing server at http://%s", ps.config.ListenURL)
+ return ps.server.ListenAndServe()
+}
+
+// Shutdown attempts to gracefully shutdown the server.
+func (ps *PhishingServer) Shutdown() error {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+ defer cancel()
+ return ps.server.Shutdown(ctx)
+}
+
+// CreatePhishingRouter creates the router that handles phishing connections.
+func (ps *PhishingServer) registerRoutes() {
+ router := mux.NewRouter()
+ fileServer := http.FileServer(unindexed.Dir("./static/endpoint/"))
+ router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
+ router.HandleFunc("/track", ps.TrackHandler)
+ router.HandleFunc("/robots.txt", ps.RobotsHandler)
+ router.HandleFunc("/{path:.*}/track", ps.TrackHandler)
+ router.HandleFunc("/{path:.*}/report", ps.ReportHandler)
+ router.HandleFunc("/report", ps.ReportHandler)
+ router.HandleFunc("/{path:.*}", ps.PhishHandler)
+
+ // Setup GZIP compression
+ gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
+ phishHandler := gzipWrapper(router)
+
+ // Setup logging
+ phishHandler = handlers.CombinedLoggingHandler(log.Writer(), phishHandler)
+ ps.server.Handler = router
+}
+
+// TrackHandler tracks emails as they are opened, updating the status for the given Result
+func (ps *PhishingServer) TrackHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r)
if err != nil {
// Log the error if it wasn't something we can safely ignore
@@ -71,7 +146,7 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) {
- TransparencyHandler(w, r)
+ ps.TransparencyHandler(w, r)
return
}
@@ -82,8 +157,8 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "static/images/pixel.png")
}
-// PhishReporter tracks emails as they are reported, updating the status for the given Result
-func PhishReporter(w http.ResponseWriter, r *http.Request) {
+// ReportHandler tracks emails as they are reported, updating the status for the given Result
+func (ps *PhishingServer) ReportHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r)
if err != nil {
// Log the error if it wasn't something we can safely ignore
@@ -104,7 +179,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) {
- TransparencyHandler(w, r)
+ ps.TransparencyHandler(w, r)
return
}
@@ -117,7 +192,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
// PhishHandler handles incoming client connections and registers the associated actions performed
// (such as clicked link, etc.)
-func PhishHandler(w http.ResponseWriter, r *http.Request) {
+func (ps *PhishingServer) PhishHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r)
if err != nil {
// Log the error if it wasn't something we can safely ignore
@@ -152,7 +227,7 @@ func PhishHandler(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) {
- TransparencyHandler(w, r)
+ ps.TransparencyHandler(w, r)
return
}
@@ -211,23 +286,24 @@ func renderPhishResponse(w http.ResponseWriter, r *http.Request, ptx models.Phis
}
// RobotsHandler prevents search engines, etc. from indexing phishing materials
-func RobotsHandler(w http.ResponseWriter, r *http.Request) {
+func (ps *PhishingServer) RobotsHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "User-agent: *\nDisallow: /")
}
// TransparencyHandler returns a TransparencyResponse for the provided result
// and campaign.
-func TransparencyHandler(w http.ResponseWriter, r *http.Request) {
+func (ps *PhishingServer) TransparencyHandler(w http.ResponseWriter, r *http.Request) {
rs := ctx.Get(r, "result").(models.Result)
tr := &TransparencyResponse{
Server: config.ServerName,
SendDate: rs.SendDate,
- ContactAddress: config.Conf.ContactAddress,
+ ContactAddress: ps.contactAddress,
}
JSONResponse(w, tr, http.StatusOK)
}
-// setupContext handles some of the administrative work around receiving a new request, such as checking the result ID, the campaign, etc.
+// setupContext handles some of the administrative work around receiving a new
+// request, such as checking the result ID, the campaign, etc.
func setupContext(r *http.Request) (error, *http.Request) {
err := r.ParseForm()
if err != nil {
diff --git a/controllers/phish_test.go b/controllers/phish_test.go
index 59425c8a..dca961a6 100644
--- a/controllers/phish_test.go
+++ b/controllers/phish_test.go
@@ -38,7 +38,7 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest {
}
func (s *ControllersSuite) openEmail(rid string) {
- resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
@@ -49,19 +49,19 @@ func (s *ControllersSuite) openEmail(rid string) {
}
func (s *ControllersSuite) reportedEmail(rid string) {
- resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNoContent)
}
func (s *ControllersSuite) reportEmail404(rid string) {
- resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound)
}
func (s *ControllersSuite) openEmail404(rid string) {
- resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close()
s.Nil(err)
@@ -69,7 +69,7 @@ func (s *ControllersSuite) openEmail404(rid string) {
}
func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
- resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
@@ -79,7 +79,7 @@ func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
}
func (s *ControllersSuite) clickLink404(rid string) {
- resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close()
s.Nil(err)
@@ -87,14 +87,14 @@ func (s *ControllersSuite) clickLink404(rid string) {
}
func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) {
- resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ps.URL, path, models.RecipientParameter, rid))
+ resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid))
s.Nil(err)
defer resp.Body.Close()
s.Equal(resp.StatusCode, http.StatusOK)
tr := &TransparencyResponse{}
err = json.NewDecoder(resp.Body).Decode(tr)
s.Nil(err)
- s.Equal(tr.ContactAddress, config.Conf.ContactAddress)
+ s.Equal(tr.ContactAddress, s.config.ContactAddress)
s.Equal(tr.SendDate, r.SendDate)
s.Equal(tr.Server, config.ServerName)
}
@@ -146,11 +146,11 @@ func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() {
}
func (s *ControllersSuite) TestNoRecipientID() {
- resp, err := http.Get(fmt.Sprintf("%s/track", ps.URL))
+ resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL))
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound)
- resp, err = http.Get(ps.URL)
+ resp, err = http.Get(s.phishServer.URL)
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound)
}
@@ -183,7 +183,7 @@ func (s *ControllersSuite) TestCompletedCampaignClick() {
func (s *ControllersSuite) TestRobotsHandler() {
expected := []byte("User-agent: *\nDisallow: /\n")
- resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ps.URL))
+ resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL))
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusOK)
defer resp.Body.Close()
@@ -259,7 +259,7 @@ func (s *ControllersSuite) TestRedirectTemplating() {
},
}
result := campaign.Results[0]
- resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
+ resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
s.Nil(err)
defer resp.Body.Close()
s.Equal(http.StatusFound, resp.StatusCode)
diff --git a/controllers/route.go b/controllers/route.go
index ff31ece7..c27499d8 100644
--- a/controllers/route.go
+++ b/controllers/route.go
@@ -1,72 +1,153 @@
package controllers
import (
- "fmt"
+ "compress/gzip"
+ "context"
"html/template"
"net/http"
"net/url"
+ "time"
+ "github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/auth"
"github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context"
log "github.com/gophish/gophish/logger"
mid "github.com/gophish/gophish/middleware"
"github.com/gophish/gophish/models"
+ "github.com/gophish/gophish/util"
+ "github.com/gophish/gophish/worker"
"github.com/gorilla/csrf"
+ "github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
+ "github.com/jordan-wright/unindexed"
)
-// CreateAdminRouter creates the routes for handling requests to the web interface.
+// AdminServerOption is a functional option that is used to configure the
+// admin server
+type AdminServerOption func(*AdminServer)
+
+// AdminServer is an HTTP server that implements the administrative Gophish
+// handlers, including the dashboard and REST API.
+type AdminServer struct {
+ server *http.Server
+ worker worker.Worker
+ config config.AdminServer
+}
+
+// WithWorker is an option that sets the background worker.
+func WithWorker(w worker.Worker) AdminServerOption {
+ return func(as *AdminServer) {
+ as.worker = w
+ }
+}
+
+// NewAdminServer returns a new instance of the AdminServer with the
+// provided config and options applied.
+func NewAdminServer(config config.AdminServer, options ...AdminServerOption) *AdminServer {
+ defaultWorker, _ := worker.New()
+ defaultServer := &http.Server{
+ ReadTimeout: 10 * time.Second,
+ Addr: config.ListenURL,
+ }
+ as := &AdminServer{
+ worker: defaultWorker,
+ server: defaultServer,
+ config: config,
+ }
+ for _, opt := range options {
+ opt(as)
+ }
+ as.registerRoutes()
+ return as
+}
+
+// Start launches the admin server, listening on the configured address.
+func (as *AdminServer) Start() error {
+ if as.worker != nil {
+ go as.worker.Start()
+ }
+ if as.config.UseTLS {
+ err := util.CheckAndCreateSSL(as.config.CertPath, as.config.KeyPath)
+ if err != nil {
+ log.Fatal(err)
+ return err
+ }
+ log.Infof("Starting admin server at https://%s", as.config.ListenURL)
+ return as.server.ListenAndServeTLS(as.config.CertPath, as.config.KeyPath)
+ }
+ // If TLS isn't configured, just listen on HTTP
+ log.Infof("Starting admin server at http://%s", as.config.ListenURL)
+ return as.server.ListenAndServe()
+}
+
+// Shutdown attempts to gracefully shutdown the server.
+func (as *AdminServer) Shutdown() error {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+ defer cancel()
+ return as.server.Shutdown(ctx)
+}
+
+// SetupAdminRoutes creates the routes for handling requests to the web interface.
// This function returns an http.Handler to be used in http.ListenAndServe().
-func CreateAdminRouter() http.Handler {
+func (as *AdminServer) registerRoutes() {
router := mux.NewRouter()
// Base Front-end routes
- router.HandleFunc("/", Use(Base, mid.RequireLogin))
- router.HandleFunc("/login", Login)
- router.HandleFunc("/logout", Use(Logout, mid.RequireLogin))
- router.HandleFunc("/campaigns", Use(Campaigns, mid.RequireLogin))
- router.HandleFunc("/campaigns/{id:[0-9]+}", Use(CampaignID, mid.RequireLogin))
- router.HandleFunc("/templates", Use(Templates, mid.RequireLogin))
- router.HandleFunc("/users", Use(Users, mid.RequireLogin))
- router.HandleFunc("/landing_pages", Use(LandingPages, mid.RequireLogin))
- router.HandleFunc("/sending_profiles", Use(SendingProfiles, mid.RequireLogin))
- router.HandleFunc("/register", Use(Register, mid.RequireLogin))
- router.HandleFunc("/settings", Use(Settings, mid.RequireLogin))
+ router.HandleFunc("/", Use(as.Base, mid.RequireLogin))
+ router.HandleFunc("/login", as.Login)
+ router.HandleFunc("/logout", Use(as.Logout, mid.RequireLogin))
+ router.HandleFunc("/campaigns", Use(as.Campaigns, mid.RequireLogin))
+ router.HandleFunc("/campaigns/{id:[0-9]+}", Use(as.CampaignID, mid.RequireLogin))
+ router.HandleFunc("/templates", Use(as.Templates, mid.RequireLogin))
+ router.HandleFunc("/users", Use(as.Users, mid.RequireLogin))
+ router.HandleFunc("/landing_pages", Use(as.LandingPages, mid.RequireLogin))
+ router.HandleFunc("/sending_profiles", Use(as.SendingProfiles, mid.RequireLogin))
+ router.HandleFunc("/register", Use(as.Register, mid.RequireLogin))
+ router.HandleFunc("/settings", Use(as.Settings, mid.RequireLogin))
// Create the API routes
api := router.PathPrefix("/api").Subrouter()
api = api.StrictSlash(true)
- api.HandleFunc("/reset", Use(API_Reset, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/", Use(API_Campaigns, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/summary", Use(API_Campaigns_Summary, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/{id:[0-9]+}", Use(API_Campaigns_Id, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/{id:[0-9]+}/results", Use(API_Campaigns_Id_Results, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/{id:[0-9]+}/summary", Use(API_Campaign_Id_Summary, mid.RequireAPIKey))
- api.HandleFunc("/campaigns/{id:[0-9]+}/complete", Use(API_Campaigns_Id_Complete, mid.RequireAPIKey))
- api.HandleFunc("/groups/", Use(API_Groups, mid.RequireAPIKey))
- api.HandleFunc("/groups/summary", Use(API_Groups_Summary, mid.RequireAPIKey))
- api.HandleFunc("/groups/{id:[0-9]+}", Use(API_Groups_Id, mid.RequireAPIKey))
- api.HandleFunc("/groups/{id:[0-9]+}/summary", Use(API_Groups_Id_Summary, mid.RequireAPIKey))
- api.HandleFunc("/templates/", Use(API_Templates, mid.RequireAPIKey))
- api.HandleFunc("/templates/{id:[0-9]+}", Use(API_Templates_Id, mid.RequireAPIKey))
- api.HandleFunc("/pages/", Use(API_Pages, mid.RequireAPIKey))
- api.HandleFunc("/pages/{id:[0-9]+}", Use(API_Pages_Id, mid.RequireAPIKey))
- api.HandleFunc("/smtp/", Use(API_SMTP, mid.RequireAPIKey))
- api.HandleFunc("/smtp/{id:[0-9]+}", Use(API_SMTP_Id, mid.RequireAPIKey))
- api.HandleFunc("/util/send_test_email", Use(API_Send_Test_Email, mid.RequireAPIKey))
- api.HandleFunc("/import/group", Use(API_Import_Group, mid.RequireAPIKey))
- api.HandleFunc("/import/email", Use(API_Import_Email, mid.RequireAPIKey))
- api.HandleFunc("/import/site", Use(API_Import_Site, mid.RequireAPIKey))
+ api.Use(mid.RequireAPIKey)
+ api.HandleFunc("/reset", as.API_Reset)
+ api.HandleFunc("/campaigns/", as.API_Campaigns)
+ api.HandleFunc("/campaigns/summary", as.API_Campaigns_Summary)
+ api.HandleFunc("/campaigns/{id:[0-9]+}", as.API_Campaigns_Id)
+ api.HandleFunc("/campaigns/{id:[0-9]+}/results", as.API_Campaigns_Id_Results)
+ api.HandleFunc("/campaigns/{id:[0-9]+}/summary", as.API_Campaign_Id_Summary)
+ api.HandleFunc("/campaigns/{id:[0-9]+}/complete", as.API_Campaigns_Id_Complete)
+ api.HandleFunc("/groups/", as.API_Groups)
+ api.HandleFunc("/groups/summary", as.API_Groups_Summary)
+ api.HandleFunc("/groups/{id:[0-9]+}", as.API_Groups_Id)
+ api.HandleFunc("/groups/{id:[0-9]+}/summary", as.API_Groups_Id_Summary)
+ api.HandleFunc("/templates/", as.API_Templates)
+ api.HandleFunc("/templates/{id:[0-9]+}", as.API_Templates_Id)
+ api.HandleFunc("/pages/", as.API_Pages)
+ api.HandleFunc("/pages/{id:[0-9]+}", as.API_Pages_Id)
+ api.HandleFunc("/smtp/", as.API_SMTP)
+ api.HandleFunc("/smtp/{id:[0-9]+}", as.API_SMTP_Id)
+ api.HandleFunc("/util/send_test_email", as.API_Send_Test_Email)
+ api.HandleFunc("/import/group", as.API_Import_Group)
+ api.HandleFunc("/import/email", as.API_Import_Email)
+ api.HandleFunc("/import/site", as.API_Import_Site)
// Setup static file serving
- router.PathPrefix("/").Handler(http.FileServer(UnindexedFileSystem{http.Dir("./static/")}))
+ router.PathPrefix("/").Handler(http.FileServer(unindexed.Dir("./static/")))
// Setup CSRF Protection
csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()),
csrf.FieldName("csrf_token"),
- csrf.Secure(config.Conf.AdminConf.UseTLS))
- csrfRouter := csrfHandler(router)
- return Use(csrfRouter.ServeHTTP, mid.CSRFExceptions, mid.GetContext)
+ csrf.Secure(as.config.UseTLS))
+ adminHandler := csrfHandler(router)
+ adminHandler = Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext)
+
+ // Setup GZIP compression
+ gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
+ adminHandler = gzipWrapper(adminHandler)
+
+ // Setup logging
+ adminHandler = handlers.CombinedLoggingHandler(log.Writer(), adminHandler)
+ as.server.Handler = adminHandler
}
// Use allows us to stack middleware to process the request
@@ -78,16 +159,29 @@ func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) h
return handler
}
+type templateParams struct {
+ Title string
+ Flashes []interface{}
+ User models.User
+ Token string
+ Version string
+}
+
+// newTemplateParams returns the default template parameters for a user and
+// the CSRF token.
+func newTemplateParams(r *http.Request) templateParams {
+ return templateParams{
+ Token: csrf.Token(r),
+ User: ctx.Get(r, "user").(models.User),
+ Version: config.Version,
+ }
+}
+
// Register creates a new user
-func Register(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) Register(w http.ResponseWriter, r *http.Request) {
// If it is a post request, attempt to register the account
// Now that we are all registered, we can log the user in
- params := struct {
- Title string
- Flashes []interface{}
- User models.User
- Token string
- }{Title: "Register", Token: csrf.Token(r)}
+ params := templateParams{Title: "Register", Token: csrf.Token(r)}
session := ctx.Get(r, "session").(*sessions.Session)
switch {
case r.Method == "GET":
@@ -120,99 +214,60 @@ func Register(w http.ResponseWriter, r *http.Request) {
}
// Base handles the default path and template execution
-func Base(w http.ResponseWriter, r *http.Request) {
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Dashboard", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) Base(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Dashboard"
getTemplate(w, "dashboard").ExecuteTemplate(w, "base", params)
}
// Campaigns handles the default path and template execution
-func Campaigns(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Campaigns", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) Campaigns(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Campaigns"
getTemplate(w, "campaigns").ExecuteTemplate(w, "base", params)
}
// CampaignID handles the default path and template execution
-func CampaignID(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Campaign Results", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) CampaignID(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Campaign Results"
getTemplate(w, "campaign_results").ExecuteTemplate(w, "base", params)
}
// Templates handles the default path and template execution
-func Templates(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Email Templates", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) Templates(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Email Templates"
getTemplate(w, "templates").ExecuteTemplate(w, "base", params)
}
// Users handles the default path and template execution
-func Users(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Users & Groups", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) Users(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Users & Groups"
getTemplate(w, "users").ExecuteTemplate(w, "base", params)
}
// LandingPages handles the default path and template execution
-func LandingPages(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Landing Pages", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) LandingPages(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Landing Pages"
getTemplate(w, "landing_pages").ExecuteTemplate(w, "base", params)
}
// SendingProfiles handles the default path and template execution
-func SendingProfiles(w http.ResponseWriter, r *http.Request) {
- // Example of using session - will be removed.
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- }{Title: "Sending Profiles", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+func (as *AdminServer) SendingProfiles(w http.ResponseWriter, r *http.Request) {
+ params := newTemplateParams(r)
+ params.Title = "Sending Profiles"
getTemplate(w, "sending_profiles").ExecuteTemplate(w, "base", params)
}
// Settings handles the changing of settings
-func Settings(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) Settings(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == "GET":
- params := struct {
- User models.User
- Title string
- Flashes []interface{}
- Token string
- Version string
- }{Title: "Settings", Version: config.Version, User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
+ params := newTemplateParams(r)
+ params.Title = "Settings"
getTemplate(w, "settings").ExecuteTemplate(w, "base", params)
case r.Method == "POST":
err := auth.ChangePassword(r)
@@ -235,7 +290,7 @@ func Settings(w http.ResponseWriter, r *http.Request) {
// Login handles the authentication flow for a user. If credentials are valid,
// a session is created
-func Login(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) Login(w http.ResponseWriter, r *http.Request) {
params := struct {
User models.User
Title string
@@ -289,7 +344,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
}
// Logout destroys the current user session
-func Logout(w http.ResponseWriter, r *http.Request) {
+func (as *AdminServer) Logout(w http.ResponseWriter, r *http.Request) {
session := ctx.Get(r, "session").(*sessions.Session)
delete(session.Values, "id")
Flash(w, r, "success", "You have successfully logged out")
@@ -297,28 +352,6 @@ func Logout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", 302)
}
-// Preview allows for the viewing of page html in a separate browser window
-func Preview(w http.ResponseWriter, r *http.Request) {
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusBadRequest)
- return
- }
- fmt.Fprintf(w, "%s", r.FormValue("html"))
-}
-
-// Clone takes a URL as a POST parameter and returns the site HTML
-func Clone(w http.ResponseWriter, r *http.Request) {
- vars := mux.Vars(r)
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusBadRequest)
- return
- }
- if url, ok := vars["url"]; ok {
- log.Error(url)
- }
- http.Error(w, "No URL given.", http.StatusBadRequest)
-}
-
func getTemplate(w http.ResponseWriter, tmpl string) *template.Template {
templates := template.New("template")
_, err := templates.ParseFiles("templates/base.html", "templates/"+tmpl+".html", "templates/flashes.html")
diff --git a/controllers/route_test.go b/controllers/route_test.go
index 09d7ec45..c9c7975e 100644
--- a/controllers/route_test.go
+++ b/controllers/route_test.go
@@ -10,7 +10,7 @@ import (
)
func (s *ControllersSuite) TestLoginCSRF() {
- resp, err := http.PostForm(fmt.Sprintf("%s/login", as.URL),
+ resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL),
url.Values{
"username": {"admin"},
"password": {"gophish"},
@@ -21,7 +21,7 @@ func (s *ControllersSuite) TestLoginCSRF() {
}
func (s *ControllersSuite) TestInvalidCredentials() {
- resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
+ resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
@@ -32,7 +32,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
s.Equal(ok, true)
client := &http.Client{}
- req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{
+ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"invalid"},
"csrf_token": {token},
@@ -48,7 +48,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
}
func (s *ControllersSuite) TestSuccessfulLogin() {
- resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
+ resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
@@ -59,7 +59,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
s.Equal(ok, true)
client := &http.Client{}
- req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{
+ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"gophish"},
"csrf_token": {token},
@@ -76,7 +76,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
func (s *ControllersSuite) TestSuccessfulRedirect() {
next := "/campaigns"
- resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
+ resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK)
@@ -91,7 +91,7 @@ func (s *ControllersSuite) TestSuccessfulRedirect() {
return http.ErrUseLastResponse
},
}
- req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", as.URL, next), strings.NewReader(url.Values{
+ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{
"username": {"admin"},
"password": {"gophish"},
"csrf_token": {token},
diff --git a/controllers/static.go b/controllers/static.go
deleted file mode 100644
index 4cdcbd6d..00000000
--- a/controllers/static.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package controllers
-
-import (
- "net/http"
- "strings"
-)
-
-// UnindexedFileSystem is an implementation of a standard http.FileSystem
-// without the ability to list files in the directory.
-// This implementation is largely inspired by
-// https://www.alexedwards.net/blog/disable-http-fileserver-directory-listings
-type UnindexedFileSystem struct {
- fs http.FileSystem
-}
-
-// Open returns a file from the static directory. If the requested path ends
-// with a slash, there is a check for an index.html file. If none exists, then
-// an error is returned.
-func (ufs UnindexedFileSystem) Open(name string) (http.File, error) {
- f, err := ufs.fs.Open(name)
- if err != nil {
- return nil, err
- }
-
- s, err := f.Stat()
- if s.IsDir() {
- index := strings.TrimSuffix(name, "/") + "/index.html"
- indexFile, err := ufs.fs.Open(index)
- if err != nil {
- return nil, err
- }
- return indexFile, nil
- }
- return f, nil
-}
diff --git a/controllers/static_test.go b/controllers/static_test.go
deleted file mode 100644
index 0f36464f..00000000
--- a/controllers/static_test.go
+++ /dev/null
@@ -1,81 +0,0 @@
-package controllers
-
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "net/http"
- "os"
- "path/filepath"
-)
-
-var fileContent = []byte("Hello world")
-
-func mustRemoveAll(dir string) {
- err := os.RemoveAll(dir)
- if err != nil {
- panic(err)
- }
-}
-
-func createTestFile(dir, filename string) error {
- return ioutil.WriteFile(filepath.Join(dir, filename), fileContent, 0644)
-}
-
-func (s *ControllersSuite) TestGetStaticFile() {
- dir, err := ioutil.TempDir("static/endpoint", "test-")
- tempFolder := filepath.Base(dir)
-
- s.Nil(err)
- defer mustRemoveAll(dir)
-
- err = createTestFile(dir, "foo.txt")
- s.Nil(nil, err)
-
- resp, err := http.Get(fmt.Sprintf("%s/static/%s/foo.txt", ps.URL, tempFolder))
- s.Nil(err)
-
- defer resp.Body.Close()
- got, err := ioutil.ReadAll(resp.Body)
- s.Nil(err)
-
- s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
-}
-
-func (s *ControllersSuite) TestStaticFileListing() {
- dir, err := ioutil.TempDir("static/endpoint", "test-")
- tempFolder := filepath.Base(dir)
-
- s.Nil(err)
- defer mustRemoveAll(dir)
-
- err = createTestFile(dir, "foo.txt")
- s.Nil(nil, err)
-
- resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
- s.Nil(err)
-
- defer resp.Body.Close()
- s.Nil(err)
- s.Equal(resp.StatusCode, http.StatusNotFound)
-}
-
-func (s *ControllersSuite) TestStaticIndex() {
- dir, err := ioutil.TempDir("static/endpoint", "test-")
- tempFolder := filepath.Base(dir)
-
- s.Nil(err)
- defer mustRemoveAll(dir)
-
- err = createTestFile(dir, "index.html")
- s.Nil(nil, err)
-
- resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
- s.Nil(err)
-
- defer resp.Body.Close()
- got, err := ioutil.ReadAll(resp.Body)
- s.Nil(err)
-
- s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
-}
diff --git a/gophish.go b/gophish.go
index 514323ab..6c6fbd9f 100644
--- a/gophish.go
+++ b/gophish.go
@@ -26,24 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
import (
- "compress/gzip"
- "context"
"io/ioutil"
- "net/http"
"os"
- "sync"
+ "os/signal"
"gopkg.in/alecthomas/kingpin.v2"
- "github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/auth"
"github.com/gophish/gophish/config"
"github.com/gophish/gophish/controllers"
log "github.com/gophish/gophish/logger"
- "github.com/gophish/gophish/mailer"
"github.com/gophish/gophish/models"
- "github.com/gophish/gophish/util"
- "github.com/gorilla/handlers"
)
var (
@@ -65,31 +58,25 @@ func main() {
kingpin.Parse()
// Load the config
- err = config.LoadConfig(*configPath)
+ conf, err := config.LoadConfig(*configPath)
// Just warn if a contact address hasn't been configured
if err != nil {
log.Fatal(err)
}
- if config.Conf.ContactAddress == "" {
+ if conf.ContactAddress == "" {
log.Warnf("No contact address has been configured.")
log.Warnf("Please consider adding a contact_address entry in your config.json")
}
config.Version = string(version)
- err = log.Setup()
+ err = log.Setup(conf)
if err != nil {
log.Fatal(err)
}
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
// Provide the option to disable the built-in mailer
- if !*disableMailer {
- go mailer.Mailer.Start(ctx)
- }
// Setup the global variables and settings
- err = models.Setup()
+ err = models.Setup(conf)
if err != nil {
log.Fatal(err)
}
@@ -99,39 +86,27 @@ func main() {
if err != nil {
log.Fatal(err)
}
- wg := &sync.WaitGroup{}
- wg.Add(1)
- // Start the web servers
- go func() {
- defer wg.Done()
- gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
- adminHandler := gzipWrapper(controllers.CreateAdminRouter())
- auth.Store.Options.Secure = config.Conf.AdminConf.UseTLS
- if config.Conf.AdminConf.UseTLS { // use TLS for Admin web server if available
- err := util.CheckAndCreateSSL(config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath)
- if err != nil {
- log.Fatal(err)
- }
- log.Infof("Starting admin server at https://%s", config.Conf.AdminConf.ListenURL)
- log.Info(http.ListenAndServeTLS(config.Conf.AdminConf.ListenURL, config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath,
- handlers.CombinedLoggingHandler(log.Writer(), adminHandler)))
- } else {
- log.Infof("Starting admin server at http://%s", config.Conf.AdminConf.ListenURL)
- log.Info(http.ListenAndServe(config.Conf.AdminConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, adminHandler)))
- }
- }()
- wg.Add(1)
- go func() {
- defer wg.Done()
- phishHandler := gziphandler.GzipHandler(controllers.CreatePhishingRouter())
- if config.Conf.PhishConf.UseTLS { // use TLS for Phish web server if available
- log.Infof("Starting phishing server at https://%s", config.Conf.PhishConf.ListenURL)
- log.Info(http.ListenAndServeTLS(config.Conf.PhishConf.ListenURL, config.Conf.PhishConf.CertPath, config.Conf.PhishConf.KeyPath,
- handlers.CombinedLoggingHandler(log.Writer(), phishHandler)))
- } else {
- log.Infof("Starting phishing server at http://%s", config.Conf.PhishConf.ListenURL)
- log.Fatal(http.ListenAndServe(config.Conf.PhishConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, phishHandler)))
- }
- }()
- wg.Wait()
+
+ // Create our servers
+ adminOptions := []controllers.AdminServerOption{}
+ if *disableMailer {
+ adminOptions = append(adminOptions, controllers.WithWorker(nil))
+ }
+ adminConfig := conf.AdminConf
+ adminServer := controllers.NewAdminServer(adminConfig, adminOptions...)
+ auth.Store.Options.Secure = adminConfig.UseTLS
+
+ phishConfig := conf.PhishConf
+ phishServer := controllers.NewPhishingServer(phishConfig)
+
+ go adminServer.Start()
+ go phishServer.Start()
+
+ // Handle graceful shutdown
+ c := make(chan os.Signal, 1)
+ signal.Notify(c, os.Interrupt)
+ <-c
+ log.Info("CTRL+C Received... Gracefully shutting down servers")
+ adminServer.Shutdown()
+ phishServer.Shutdown()
}
diff --git a/logger/logger.go b/logger/logger.go
index feb7a167..1e107258 100644
--- a/logger/logger.go
+++ b/logger/logger.go
@@ -18,10 +18,10 @@ func init() {
}
// Setup configures the logger based on options in the config.json.
-func Setup() error {
+func Setup(conf *config.Config) error {
Logger.SetLevel(logrus.InfoLevel)
// Set up logging to a file if specified in the config
- logFile := config.Conf.Logging.Filename
+ logFile := conf.Logging.Filename
if logFile != "" {
f, err := os.OpenFile(logFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
diff --git a/mailer/mailer.go b/mailer/mailer.go
index 4a2a7410..99a3f38a 100644
--- a/mailer/mailer.go
+++ b/mailer/mailer.go
@@ -29,6 +29,13 @@ func (e *ErrMaxConnectAttempts) Error() string {
return errString
}
+// Mailer is an interface that defines an object used to queue and
+// send mailer.Mail instances.
+type Mailer interface {
+ Start(ctx context.Context)
+ Queue([]Mail)
+}
+
// Sender exposes the common operations required for sending email.
type Sender interface {
Send(from string, to []string, msg io.WriterTo) error
@@ -50,27 +57,18 @@ type Mail interface {
GetDialer() (Dialer, error)
}
-// Mailer is a global instance of the mailer that can
-// be used in applications. It is the responsibility of the application
-// to call Mailer.Start()
-var Mailer *MailWorker
-
-func init() {
- Mailer = NewMailWorker()
-}
-
// MailWorker is the worker that receives slices of emails
// on a channel to send. It's assumed that every slice of emails received is meant
// to be sent to the same server.
type MailWorker struct {
- Queue chan []Mail
+ queue chan []Mail
}
// NewMailWorker returns an instance of MailWorker with the mail queue
// initialized.
func NewMailWorker() *MailWorker {
return &MailWorker{
- Queue: make(chan []Mail),
+ queue: make(chan []Mail),
}
}
@@ -81,7 +79,7 @@ func (mw *MailWorker) Start(ctx context.Context) {
select {
case <-ctx.Done():
return
- case ms := <-mw.Queue:
+ case ms := <-mw.queue:
go func(ctx context.Context, ms []Mail) {
dialer, err := ms[0].GetDialer()
if err != nil {
@@ -94,6 +92,11 @@ func (mw *MailWorker) Start(ctx context.Context) {
}
}
+// Queue sends the provided mail to the internal queue for processing.
+func (mw *MailWorker) Queue(ms []Mail) {
+ mw.queue <- ms
+}
+
// errorMail is a helper to handle erroring out a slice of Mail instances
// in the case that an unrecoverable error occurs.
func errorMail(err error, ms []Mail) {
diff --git a/mailer/mailer_test.go b/mailer/mailer_test.go
index a1573c39..a8c0450f 100644
--- a/mailer/mailer_test.go
+++ b/mailer/mailer_test.go
@@ -88,7 +88,7 @@ func (ms *MailerSuite) TestMailWorkerStart() {
messages := generateMessages(dialer)
// Send the campaign
- mw.Queue <- messages
+ mw.Queue(messages)
got := []*mockMessage{}
@@ -129,7 +129,7 @@ func (ms *MailerSuite) TestBackoff() {
messages := generateMessages(dialer)
// Send the campaign
- mw.Queue <- messages
+ mw.Queue(messages)
got := []*mockMessage{}
@@ -183,7 +183,7 @@ func (ms *MailerSuite) TestPermError() {
messages := generateMessages(dialer)
// Send the campaign
- mw.Queue <- messages
+ mw.Queue(messages)
got := []*mockMessage{}
@@ -242,7 +242,7 @@ func (ms *MailerSuite) TestUnknownError() {
messages := generateMessages(dialer)
// Send the campaign
- mw.Queue <- messages
+ mw.Queue(messages)
got := []*mockMessage{}
diff --git a/middleware/middleware.go b/middleware/middleware.go
index d0e0187b..25e0c4b8 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -60,8 +60,10 @@ func GetContext(handler http.Handler) http.HandlerFunc {
}
}
-func RequireAPIKey(handler http.Handler) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
+// RequireAPIKey ensures that a valid API key is set as either the api_key GET
+// parameter, or a Bearer token.
+func RequireAPIKey(handler http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
@@ -81,18 +83,18 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc {
}
}
if ak == "" {
- JSONError(w, 400, "API Key not set")
+ JSONError(w, http.StatusUnauthorized, "API Key not set")
return
}
u, err := models.GetUserByAPIKey(ak)
if err != nil {
- JSONError(w, 400, "Invalid API Key")
+ JSONError(w, http.StatusUnauthorized, "Invalid API Key")
return
}
r = ctx.Set(r, "user_id", u.Id)
r = ctx.Set(r, "api_key", ak)
handler.ServeHTTP(w, r)
- }
+ })
}
// RequireLogin is a simple middleware which checks to see if the user is currently logged in.
@@ -104,7 +106,7 @@ func RequireLogin(handler http.Handler) http.HandlerFunc {
} else {
q := r.URL.Query()
q.Set("next", r.URL.Path)
- http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), 302)
+ http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
}
}
}
diff --git a/models/campaign.go b/models/campaign.go
index 2841e564..9017f787 100644
--- a/models/campaign.go
+++ b/models/campaign.go
@@ -165,7 +165,7 @@ func (c *Campaign) AddEvent(e *Event) error {
// an error is returned. Otherwise, the attribute name is set to [Deleted],
// indicating the user deleted the attribute (template, smtp, etc.)
func (c *Campaign) getDetails() error {
- err = db.Model(c).Related(&c.Results).Error
+ err := db.Model(c).Related(&c.Results).Error
if err != nil {
log.Warnf("%s: results not found for campaign", err)
return err
@@ -402,7 +402,8 @@ func GetQueuedCampaigns(t time.Time) ([]Campaign, error) {
// PostCampaign inserts a campaign and all associated records into the database.
func PostCampaign(c *Campaign, uid int64) error {
- if err := c.Validate(); err != nil {
+ err := c.Validate()
+ if err != nil {
return err
}
// Fill in the details
@@ -514,14 +515,16 @@ func PostCampaign(c *Campaign, uid int64) error {
Reported: false,
ModifiedDate: c.CreatedDate,
}
- if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
- r.Status = STATUS_SENDING
- }
err = r.GenerateId()
if err != nil {
log.Error(err)
continue
}
+ processing := false
+ if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
+ r.Status = STATUS_SENDING
+ processing = true
+ }
err = db.Save(r).Error
if err != nil {
log.WithFields(logrus.Fields{
@@ -530,7 +533,14 @@ func PostCampaign(c *Campaign, uid int64) error {
}
c.Results = append(c.Results, *r)
log.Infof("Creating maillog for %s to send at %s\n", r.Email, sendDate)
- err = GenerateMailLog(c, r, sendDate)
+ m := &MailLog{
+ UserId: c.UserId,
+ CampaignId: c.Id,
+ RId: r.RId,
+ SendDate: sendDate,
+ Processing: processing,
+ }
+ err = db.Save(m).Error
if err != nil {
log.Error(err)
continue
diff --git a/models/campaign_test.go b/models/campaign_test.go
index 7cb3e2d3..b0f36d2b 100644
--- a/models/campaign_test.go
+++ b/models/campaign_test.go
@@ -79,3 +79,29 @@ func (s *ModelsSuite) TestCampaignDateValidation(c *check.C) {
err = campaign.Validate()
c.Assert(err, check.Equals, ErrInvalidSendByDate)
}
+
+func (s *ModelsSuite) TestLaunchCampaignMaillogStatus(c *check.C) {
+ // For the first test, ensure that campaigns created with the zero date
+ // (and therefore are set to launch immediately) have maillogs that are
+ // locked to prevent race conditions.
+ campaign := s.createCampaign(c)
+ ms, err := GetMailLogsByCampaign(campaign.Id)
+ c.Assert(err, check.Equals, nil)
+
+ for _, m := range ms {
+ c.Assert(m.Processing, check.Equals, true)
+ }
+
+ // Next, verify that campaigns scheduled in the future do not lock the
+ // maillogs so that they can be picked up by the background worker.
+ campaign = s.createCampaignDependencies(c)
+ campaign.Name = "New Campaign"
+ campaign.LaunchDate = time.Now().Add(1 * time.Hour)
+ c.Assert(PostCampaign(&campaign, campaign.UserId), check.Equals, nil)
+ ms, err = GetMailLogsByCampaign(campaign.Id)
+ c.Assert(err, check.Equals, nil)
+
+ for _, m := range ms {
+ c.Assert(m.Processing, check.Equals, false)
+ }
+}
diff --git a/models/email_request.go b/models/email_request.go
index 7fb60fc7..e4744c93 100644
--- a/models/email_request.go
+++ b/models/email_request.go
@@ -122,8 +122,8 @@ func (s *EmailRequest) Generate(msg *gomail.Message) error {
// Add the transparency headers
msg.SetHeader("X-Mailer", config.ServerName)
- if config.Conf.ContactAddress != "" {
- msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress)
+ if conf.ContactAddress != "" {
+ msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
}
// Parse the customHeader templates
diff --git a/models/email_request_test.go b/models/email_request_test.go
index ef3a3dae..a954d988 100644
--- a/models/email_request_test.go
+++ b/models/email_request_test.go
@@ -26,7 +26,7 @@ func (s *ModelsSuite) TestEmailRequestBackoff(ch *check.C) {
}
expected := errors.New("Temporary Error")
go func() {
- err = req.Backoff(expected)
+ err := req.Backoff(expected)
ch.Assert(err, check.Equals, nil)
}()
ch.Assert(<-req.ErrorChan, check.Equals, expected)
@@ -38,7 +38,7 @@ func (s *ModelsSuite) TestEmailRequestError(ch *check.C) {
}
expected := errors.New("Temporary Error")
go func() {
- err = req.Error(expected)
+ err := req.Error(expected)
ch.Assert(err, check.Equals, nil)
}()
ch.Assert(<-req.ErrorChan, check.Equals, expected)
@@ -49,7 +49,7 @@ func (s *ModelsSuite) TestEmailRequestSuccess(ch *check.C) {
ErrorChan: make(chan error),
}
go func() {
- err = req.Success()
+ err := req.Success()
ch.Assert(err, check.Equals, nil)
}()
ch.Assert(<-req.ErrorChan, check.Equals, nil)
@@ -76,14 +76,14 @@ func (s *ModelsSuite) TestEmailRequestGenerate(ch *check.C) {
FromAddress: smtp.FromAddress,
}
- config.Conf.ContactAddress = "test@test.com"
+ s.config.ContactAddress = "test@test.com"
expectedHeaders := map[string]string{
"X-Mailer": config.ServerName,
- "X-Gophish-Contact": config.Conf.ContactAddress,
+ "X-Gophish-Contact": s.config.ContactAddress,
}
msg := gomail.NewMessage()
- err = req.Generate(msg)
+ err := req.Generate(msg)
ch.Assert(err, check.Equals, nil)
expected := &email.Email{
@@ -130,7 +130,7 @@ func (s *ModelsSuite) TestEmailRequestURLTemplating(ch *check.C) {
}
msg := gomail.NewMessage()
- err = req.Generate(msg)
+ err := req.Generate(msg)
ch.Assert(err, check.Equals, nil)
expectedURL := fmt.Sprintf("http://127.0.0.1/%s?%s=%s", req.Email, RecipientParameter, req.RId)
@@ -167,7 +167,7 @@ func (s *ModelsSuite) TestEmailRequestGenerateEmptySubject(ch *check.C) {
}
msg := gomail.NewMessage()
- err = req.Generate(msg)
+ err := req.Generate(msg)
ch.Assert(err, check.Equals, nil)
expected := &email.Email{
diff --git a/models/group.go b/models/group.go
index a7f0248a..198a0f6e 100644
--- a/models/group.go
+++ b/models/group.go
@@ -196,7 +196,7 @@ func PostGroup(g *Group) error {
return err
}
// Insert the group into the DB
- err = db.Save(g).Error
+ err := db.Save(g).Error
if err != nil {
log.Error(err)
return err
@@ -214,7 +214,7 @@ func PutGroup(g *Group) error {
}
// Fetch group's existing targets from database.
ts := []Target{}
- ts, err = GetTargets(g.Id)
+ ts, err := GetTargets(g.Id)
if err != nil {
log.WithFields(logrus.Fields{
"group_id": g.Id,
@@ -234,7 +234,7 @@ func PutGroup(g *Group) error {
}
// If the target does not exist in the group any longer, we delete it
if !tExists {
- err = db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error
+ err := db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error
if err != nil {
log.WithFields(logrus.Fields{
"email": t.Email,
@@ -286,14 +286,14 @@ func DeleteGroup(g *Group) error {
}
func insertTargetIntoGroup(t Target, gid int64) error {
- if _, err = mail.ParseAddress(t.Email); err != nil {
+ if _, err := mail.ParseAddress(t.Email); err != nil {
log.WithFields(logrus.Fields{
"email": t.Email,
}).Error("Invalid email")
return err
}
trans := db.Begin()
- err = trans.Where(t).FirstOrCreate(&t).Error
+ err := trans.Where(t).FirstOrCreate(&t).Error
if err != nil {
log.WithFields(logrus.Fields{
"email": t.Email,
diff --git a/models/maillog.go b/models/maillog.go
index 3b56a163..bbcc7a25 100644
--- a/models/maillog.go
+++ b/models/maillog.go
@@ -45,8 +45,7 @@ func GenerateMailLog(c *Campaign, r *Result, sendDate time.Time) error {
RId: r.RId,
SendDate: sendDate,
}
- err = db.Save(m).Error
- return err
+ return db.Save(m).Error
}
// Backoff sets the MailLog SendDate to be the next entry in an exponential
@@ -160,8 +159,8 @@ func (m *MailLog) Generate(msg *gomail.Message) error {
// Add the transparency headers
msg.SetHeader("X-Mailer", config.ServerName)
- if config.Conf.ContactAddress != "" {
- msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress)
+ if conf.ContactAddress != "" {
+ msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
}
// Parse the customHeader templates
for _, header := range c.SMTP.Headers {
@@ -260,6 +259,5 @@ func LockMailLogs(ms []*MailLog, lock bool) error {
// in the database. This is intended to be called when Gophish is started
// so that any previously locked maillogs can resume processing.
func UnlockAllMailLogs() error {
- err = db.Model(&MailLog{}).Update("processing", false).Error
- return err
+ return db.Model(&MailLog{}).Update("processing", false).Error
}
diff --git a/models/maillog_test.go b/models/maillog_test.go
index c3c73a9c..5011030f 100644
--- a/models/maillog_test.go
+++ b/models/maillog_test.go
@@ -37,7 +37,13 @@ func (s *ModelsSuite) emailFromFirstMailLog(campaign Campaign, ch *check.C) *ema
func (s *ModelsSuite) TestGetQueuedMailLogs(ch *check.C) {
campaign := s.createCampaign(ch)
- ms, err := GetQueuedMailLogs(campaign.LaunchDate)
+ // By default, for campaigns with no launch date, the maillogs are set as
+ // being processed. We need to unlock them first.
+ ms, err := GetMailLogsByCampaign(campaign.Id)
+ ch.Assert(err, check.Equals, nil)
+ err = LockMailLogs(ms, false)
+ ch.Assert(err, check.Equals, nil)
+ ms, err = GetQueuedMailLogs(campaign.LaunchDate)
ch.Assert(err, check.Equals, nil)
got := make(map[string]*MailLog)
for _, m := range ms {
@@ -222,10 +228,10 @@ func (s *ModelsSuite) TestMailLogGenerate(ch *check.C) {
}
func (s *ModelsSuite) TestMailLogGenerateTransparencyHeaders(ch *check.C) {
- config.Conf.ContactAddress = "test@test.com"
+ s.config.ContactAddress = "test@test.com"
expectedHeaders := map[string]string{
"X-Mailer": config.ServerName,
- "X-Gophish-Contact": config.Conf.ContactAddress,
+ "X-Gophish-Contact": s.config.ContactAddress,
}
campaign := s.createCampaign(ch)
got := s.emailFromFirstMailLog(campaign, ch)
@@ -264,12 +270,6 @@ func (s *ModelsSuite) TestUnlockAllMailLogs(ch *check.C) {
campaign := s.createCampaign(ch)
ms, err := GetMailLogsByCampaign(campaign.Id)
ch.Assert(err, check.Equals, nil)
- for _, m := range ms {
- ch.Assert(m.Processing, check.Equals, false)
- }
- err = LockMailLogs(ms, true)
- ms, err = GetMailLogsByCampaign(campaign.Id)
- ch.Assert(err, check.Equals, nil)
for _, m := range ms {
ch.Assert(m.Processing, check.Equals, true)
}
diff --git a/models/models.go b/models/models.go
index eb7bfe2a..52f930a5 100644
--- a/models/models.go
+++ b/models/models.go
@@ -15,7 +15,7 @@ import (
)
var db *gorm.DB
-var err error
+var conf *config.Config
const (
CAMPAIGN_IN_PROGRESS string = "In progress"
@@ -78,12 +78,14 @@ func chooseDBDriver(name, openStr string) goose.DBDriver {
// Setup initializes the Conn object
// It also populates the Gophish Config object
-func Setup() error {
+func Setup(c *config.Config) error {
+ // Setup the package-scoped config
+ conf = c
// Setup the goose configuration
migrateConf := &goose.DBConf{
- MigrationsDir: config.Conf.MigrationsPath,
+ MigrationsDir: conf.MigrationsPath,
Env: "production",
- Driver: chooseDBDriver(config.Conf.DBName, config.Conf.DBPath),
+ Driver: chooseDBDriver(conf.DBName, conf.DBPath),
}
// Get the latest possible migration
latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir)
@@ -92,7 +94,7 @@ func Setup() error {
return err
}
// Open our database connection
- db, err = gorm.Open(config.Conf.DBName, config.Conf.DBPath)
+ db, err = gorm.Open(conf.DBName, conf.DBPath)
db.LogMode(false)
db.SetLogger(log.Logger)
db.DB().SetMaxOpenConns(1)
diff --git a/models/models_test.go b/models/models_test.go
index 8374911c..f11e8bf2 100644
--- a/models/models_test.go
+++ b/models/models_test.go
@@ -10,15 +10,20 @@ import (
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
-type ModelsSuite struct{}
+type ModelsSuite struct {
+ config *config.Config
+}
var _ = check.Suite(&ModelsSuite{})
func (s *ModelsSuite) SetUpSuite(c *check.C) {
- config.Conf.DBName = "sqlite3"
- config.Conf.DBPath = ":memory:"
- config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
- err := Setup()
+ conf := &config.Config{
+ DBName: "sqlite3",
+ DBPath: ":memory:",
+ MigrationsPath: "../db/db_sqlite3/migrations/",
+ }
+ s.config = conf
+ err := Setup(conf)
if err != nil {
c.Fatalf("Failed creating database: %v", err)
}
diff --git a/models/page.go b/models/page.go
index 73aca7d5..30da2179 100644
--- a/models/page.go
+++ b/models/page.go
@@ -148,7 +148,7 @@ func PutPage(p *Page) error {
// DeletePage deletes an existing page in the database.
// An error is returned if a page with the given user id and page id is not found.
func DeletePage(id int64, uid int64) error {
- err = db.Where("user_id=?", uid).Delete(Page{Id: id}).Error
+ err := db.Where("user_id=?", uid).Delete(Page{Id: id}).Error
if err != nil {
log.Error(err)
}
diff --git a/models/smtp.go b/models/smtp.go
index ea7d4c23..f6dca63e 100644
--- a/models/smtp.go
+++ b/models/smtp.go
@@ -228,7 +228,7 @@ func PutSMTP(s *SMTP) error {
// An error is returned if a SMTP with the given user id and SMTP id is not found.
func DeleteSMTP(id int64, uid int64) error {
// Delete all custom headers
- err = db.Where("smtp_id=?", id).Delete(&Header{}).Error
+ err := db.Where("smtp_id=?", id).Delete(&Header{}).Error
if err != nil {
log.Error(err)
return err
diff --git a/models/smtp_test.go b/models/smtp_test.go
index ecf17f58..e6552a9e 100644
--- a/models/smtp_test.go
+++ b/models/smtp_test.go
@@ -15,7 +15,7 @@ func (s *ModelsSuite) TestPostSMTP(c *check.C) {
FromAddress: "Foo Bar ",
UserId: 1,
}
- err = PostSMTP(&smtp)
+ err := PostSMTP(&smtp)
c.Assert(err, check.Equals, nil)
ss, err := GetSMTPs(1)
c.Assert(err, check.Equals, nil)
@@ -28,7 +28,7 @@ func (s *ModelsSuite) TestPostSMTPNoHost(c *check.C) {
FromAddress: "Foo Bar ",
UserId: 1,
}
- err = PostSMTP(&smtp)
+ err := PostSMTP(&smtp)
c.Assert(err, check.Equals, ErrHostNotSpecified)
}
@@ -38,7 +38,7 @@ func (s *ModelsSuite) TestPostSMTPNoFrom(c *check.C) {
UserId: 1,
Host: "1.1.1.1:25",
}
- err = PostSMTP(&smtp)
+ err := PostSMTP(&smtp)
c.Assert(err, check.Equals, ErrFromAddressNotSpecified)
}
@@ -53,7 +53,7 @@ func (s *ModelsSuite) TestPostSMTPValidHeader(c *check.C) {
Header{Key: "X-Mailer", Value: "gophish"},
},
}
- err = PostSMTP(&smtp)
+ err := PostSMTP(&smtp)
c.Assert(err, check.Equals, nil)
ss, err := GetSMTPs(1)
c.Assert(err, check.Equals, nil)
diff --git a/models/template.go b/models/template.go
index a4e071ca..8f0637e9 100644
--- a/models/template.go
+++ b/models/template.go
@@ -34,10 +34,10 @@ func (t *Template) Validate() error {
case t.Text == "" && t.HTML == "":
return ErrTemplateMissingParameter
}
- if err = ValidateTemplate(t.HTML); err != nil {
+ if err := ValidateTemplate(t.HTML); err != nil {
return err
}
- if err = ValidateTemplate(t.Text); err != nil {
+ if err := ValidateTemplate(t.Text); err != nil {
return err
}
return nil
@@ -113,7 +113,7 @@ func PostTemplate(t *Template) error {
if err := t.Validate(); err != nil {
return err
}
- err = db.Save(t).Error
+ err := db.Save(t).Error
if err != nil {
log.Error(err)
return err
@@ -138,7 +138,7 @@ func PutTemplate(t *Template) error {
return err
}
// Delete all attachments, and replace with new ones
- err = db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error
+ err := db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error
if err != nil && err != gorm.ErrRecordNotFound {
log.Error(err)
return err
@@ -146,7 +146,7 @@ func PutTemplate(t *Template) error {
if err == gorm.ErrRecordNotFound {
err = nil
}
- for i, _ := range t.Attachments {
+ for i := range t.Attachments {
t.Attachments[i].TemplateId = t.Id
err := db.Save(&t.Attachments[i]).Error
if err != nil {
diff --git a/util/util_test.go b/util/util_test.go
index 1274071c..874e3cbf 100644
--- a/util/util_test.go
+++ b/util/util_test.go
@@ -18,10 +18,12 @@ type UtilSuite struct {
}
func (s *UtilSuite) SetupSuite() {
- config.Conf.DBName = "sqlite3"
- config.Conf.DBPath = ":memory:"
- config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
- err := models.Setup()
+ conf := &config.Config{
+ DBName: "sqlite3",
+ DBPath: ":memory:",
+ MigrationsPath: "../db/db_sqlite3/migrations/",
+ }
+ err := models.Setup(conf)
if err != nil {
s.T().Fatalf("Failed creating database: %v", err)
}
diff --git a/worker/worker.go b/worker/worker.go
index eaa6c6cc..7161ec4b 100644
--- a/worker/worker.go
+++ b/worker/worker.go
@@ -1,6 +1,7 @@
package worker
import (
+ "context"
"time"
log "github.com/gophish/gophish/logger"
@@ -9,18 +10,46 @@ import (
"github.com/sirupsen/logrus"
)
-// Worker is the background worker that handles watching for new campaigns and sending emails appropriately.
-type Worker struct{}
+// Worker is an interface that defines the operations needed for a background worker
+type Worker interface {
+ Start()
+ LaunchCampaign(c models.Campaign)
+ SendTestEmail(s *models.EmailRequest) error
+}
+
+// DefaultWorker is the background worker that handles watching for new campaigns and sending emails appropriately.
+type DefaultWorker struct {
+ mailer mailer.Mailer
+}
// New creates a new worker object to handle the creation of campaigns
-func New() *Worker {
- return &Worker{}
+func New(options ...func(Worker) error) (Worker, error) {
+ defaultMailer := mailer.NewMailWorker()
+ w := &DefaultWorker{
+ mailer: defaultMailer,
+ }
+ for _, opt := range options {
+ if err := opt(w); err != nil {
+ return nil, err
+ }
+ }
+ return w, nil
+}
+
+// WithMailer sets the mailer for a given worker.
+// By default, workers use a standard, default mailworker.
+func WithMailer(m mailer.Mailer) func(*DefaultWorker) error {
+ return func(w *DefaultWorker) error {
+ w.mailer = m
+ return nil
+ }
}
// Start launches the worker to poll the database every minute for any pending maillogs
// that need to be processed.
-func (w *Worker) Start() {
+func (w *DefaultWorker) Start() {
log.Info("Background Worker Started Successfully - Waiting for Campaigns")
+ go w.mailer.Start(context.Background())
for t := range time.Tick(1 * time.Minute) {
ms, err := models.GetQueuedMailLogs(t.UTC())
if err != nil {
@@ -62,14 +91,14 @@ func (w *Worker) Start() {
log.WithFields(logrus.Fields{
"num_emails": len(msc),
}).Info("Sending emails to mailer for processing")
- mailer.Mailer.Queue <- msc
+ w.mailer.Queue(msc)
}(cid, msc)
}
}
}
// LaunchCampaign starts a campaign
-func (w *Worker) LaunchCampaign(c models.Campaign) {
+func (w *DefaultWorker) LaunchCampaign(c models.Campaign) {
ms, err := models.GetMailLogsByCampaign(c.Id)
if err != nil {
log.Error(err)
@@ -89,14 +118,14 @@ func (w *Worker) LaunchCampaign(c models.Campaign) {
}
mailEntries = append(mailEntries, m)
}
- mailer.Mailer.Queue <- mailEntries
+ w.mailer.Queue(mailEntries)
}
// SendTestEmail sends a test email
-func (w *Worker) SendTestEmail(s *models.EmailRequest) error {
+func (w *DefaultWorker) SendTestEmail(s *models.EmailRequest) error {
go func() {
ms := []mailer.Mail{s}
- mailer.Mailer.Queue <- ms
+ w.mailer.Queue(ms)
}()
return <-s.ErrorChan
}
diff --git a/worker/worker_test.go b/worker/worker_test.go
index 9ad6383d..a0252604 100644
--- a/worker/worker_test.go
+++ b/worker/worker_test.go
@@ -9,17 +9,20 @@ import (
// WorkerSuite is a suite of tests to cover API related functions
type WorkerSuite struct {
suite.Suite
- ApiKey string
+ config *config.Config
}
func (s *WorkerSuite) SetupSuite() {
- config.Conf.DBName = "sqlite3"
- config.Conf.DBPath = ":memory:"
- config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
- err := models.Setup()
+ conf := &config.Config{
+ DBName: "sqlite3",
+ DBPath: ":memory:",
+ MigrationsPath: "../db/db_sqlite3/migrations/",
+ }
+ err := models.Setup(conf)
if err != nil {
s.T().Fatalf("Failed creating database: %v", err)
}
+ s.config = conf
s.Nil(err)
}
@@ -31,7 +34,7 @@ func (s *WorkerSuite) TearDownTest() {
}
func (s *WorkerSuite) SetupTest() {
- config.Conf.TestFlag = true
+ s.config.TestFlag = true
// Add a group
group := models.Group{Name: "Test Group"}
group.Targets = []models.Target{
@@ -73,7 +76,3 @@ func (s *WorkerSuite) SetupTest() {
models.PostCampaign(&c, c.UserId)
c.UpdateStatus(models.CAMPAIGN_EMAILS_SENT)
}
-
-func (s *WorkerSuite) TestMailSendSuccess() {
- // TODO
-}