Refactor servers (#1321)

* Refactoring servers to support custom workers and graceful shutdown.
* Refactoring workers to support custom mailers.
* Refactoring mailer to be an interface, with proper instances instead of a single global instance
* Cleaning up a few things. Locking maillogs for campaigns set to launch immediately to prevent a race condition.
* Cleaning up API middleware to be simpler
* Moving template parameters to separate struct
* Changed LoadConfig to return config object
* Cleaned up some error handling, removing uninitialized global error in models package
* Changed static file serving to use the unindexed package
pull/1323/head
Jordan Wright 2018-12-15 15:42:32 -06:00 committed by GitHub
parent 3b248d25c7
commit 47f0049c30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 554 additions and 520 deletions

View File

@ -38,9 +38,6 @@ type Config struct {
Logging LoggingConfig `json:"logging"` Logging LoggingConfig `json:"logging"`
} }
// Conf contains the initialized configuration struct
var Conf Config
// Version contains the current gophish version // Version contains the current gophish version
var Version = "" var Version = ""
@ -48,19 +45,20 @@ var Version = ""
const ServerName = "gophish" const ServerName = "gophish"
// LoadConfig loads the configuration from the specified filepath // LoadConfig loads the configuration from the specified filepath
func LoadConfig(filepath string) error { func LoadConfig(filepath string) (*Config, error) {
// Get the config file // Get the config file
configFile, err := ioutil.ReadFile(filepath) configFile, err := ioutil.ReadFile(filepath)
if err != nil { if err != nil {
return err return nil, err
} }
err = json.Unmarshal(configFile, &Conf) config := &Config{}
err = json.Unmarshal(configFile, config)
if err != nil { if err != nil {
return err return nil, err
} }
// Choosing the migrations directory based on the database used. // Choosing the migrations directory based on the database used.
Conf.MigrationsPath = Conf.MigrationsPath + Conf.DBName config.MigrationsPath = config.MigrationsPath + config.DBName
// Explicitly set the TestFlag to false to prevent config.json overrides // Explicitly set the TestFlag to false to prevent config.json overrides
Conf.TestFlag = false config.TestFlag = false
return nil return config, nil
} }

View File

@ -48,18 +48,18 @@ func (s *ConfigSuite) TestLoadConfig() {
_, err := s.ConfigFile.Write(validConfig) _, err := s.ConfigFile.Write(validConfig)
s.Nil(err) s.Nil(err)
// Load the valid config // Load the valid config
err = LoadConfig(s.ConfigFile.Name()) conf, err := LoadConfig(s.ConfigFile.Name())
s.Nil(err) s.Nil(err)
expectedConfig := Config{} expectedConfig := &Config{}
err = json.Unmarshal(validConfig, &expectedConfig) err = json.Unmarshal(validConfig, &expectedConfig)
s.Nil(err) s.Nil(err)
expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName
expectedConfig.TestFlag = false expectedConfig.TestFlag = false
s.Equal(expectedConfig, Conf) s.Equal(expectedConfig, conf)
// Load an invalid config // Load an invalid config
err = LoadConfig("bogusfile") conf, err = LoadConfig("bogusfile")
s.NotNil(err) s.NotNil(err)
} }

View File

@ -17,23 +17,14 @@ import (
log "github.com/gophish/gophish/logger" log "github.com/gophish/gophish/logger"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gophish/gophish/util" "github.com/gophish/gophish/util"
"github.com/gophish/gophish/worker"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/jordan-wright/email" "github.com/jordan-wright/email"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Worker is the worker that processes phishing events and updates campaigns.
var Worker *worker.Worker
func init() {
Worker = worker.New()
go Worker.Start()
}
// API (/api/reset) resets a user's API key // API (/api/reset) resets a user's API key
func API_Reset(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Reset(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "POST": case r.Method == "POST":
u := ctx.Get(r, "user").(models.User) u := ctx.Get(r, "user").(models.User)
@ -49,7 +40,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) {
// API_Campaigns returns a list of campaigns if requested via GET. // API_Campaigns returns a list of campaigns if requested via GET.
// If requested via POST, API_Campaigns creates a new campaign and returns a reference to it. // If requested via POST, API_Campaigns creates a new campaign and returns a reference to it.
func API_Campaigns(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaigns(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
cs, err := models.GetCampaigns(ctx.Get(r, "user_id").(int64)) cs, err := models.GetCampaigns(ctx.Get(r, "user_id").(int64))
@ -74,14 +65,14 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) {
// If the campaign is scheduled to launch immediately, send it to the worker. // If the campaign is scheduled to launch immediately, send it to the worker.
// Otherwise, the worker will pick it up at the scheduled time // Otherwise, the worker will pick it up at the scheduled time
if c.Status == models.CAMPAIGN_IN_PROGRESS { if c.Status == models.CAMPAIGN_IN_PROGRESS {
go Worker.LaunchCampaign(c) go as.worker.LaunchCampaign(c)
} }
JSONResponse(w, c, http.StatusCreated) JSONResponse(w, c, http.StatusCreated)
} }
} }
// API_Campaigns_Summary returns the summary for the current user's campaigns // API_Campaigns_Summary returns the summary for the current user's campaigns
func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
cs, err := models.GetCampaignSummaries(ctx.Get(r, "user_id").(int64)) cs, err := models.GetCampaignSummaries(ctx.Get(r, "user_id").(int64))
@ -96,7 +87,7 @@ func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id returns details about the requested campaign. If the campaign is not // API_Campaigns_Id returns details about the requested campaign. If the campaign is not
// valid, API_Campaigns_Id returns null. // valid, API_Campaigns_Id returns null.
func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
c, err := models.GetCampaign(id, ctx.Get(r, "user_id").(int64)) c, err := models.GetCampaign(id, ctx.Get(r, "user_id").(int64))
@ -120,7 +111,7 @@ func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id_Results returns just the results for a given campaign to // API_Campaigns_Id_Results returns just the results for a given campaign to
// significantly reduce the information returned. // significantly reduce the information returned.
func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
cr, err := models.GetCampaignResults(id, ctx.Get(r, "user_id").(int64)) cr, err := models.GetCampaignResults(id, ctx.Get(r, "user_id").(int64))
@ -136,7 +127,7 @@ func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
} }
// API_Campaigns_Id_Summary returns just the summary for a given campaign. // API_Campaigns_Id_Summary returns just the summary for a given campaign.
func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
switch { switch {
@ -157,7 +148,7 @@ func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
// API_Campaigns_Id_Complete effectively "ends" a campaign. // API_Campaigns_Id_Complete effectively "ends" a campaign.
// Future phishing emails clicked will return a simple "404" page. // Future phishing emails clicked will return a simple "404" page.
func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
switch { switch {
@ -173,7 +164,7 @@ func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
// API_Groups returns a list of groups if requested via GET. // API_Groups returns a list of groups if requested via GET.
// If requested via POST, API_Groups creates a new group and returns a reference to it. // If requested via POST, API_Groups creates a new group and returns a reference to it.
func API_Groups(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Groups(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
gs, err := models.GetGroups(ctx.Get(r, "user_id").(int64)) gs, err := models.GetGroups(ctx.Get(r, "user_id").(int64))
@ -208,7 +199,7 @@ func API_Groups(w http.ResponseWriter, r *http.Request) {
} }
// API_Groups_Summary returns a summary of the groups owned by the current user. // API_Groups_Summary returns a summary of the groups owned by the current user.
func API_Groups_Summary(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
gs, err := models.GetGroupSummaries(ctx.Get(r, "user_id").(int64)) gs, err := models.GetGroupSummaries(ctx.Get(r, "user_id").(int64))
@ -223,7 +214,7 @@ func API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
// API_Groups_Id returns details about the requested group. // API_Groups_Id returns details about the requested group.
// If the group is not valid, API_Groups_Id returns null. // If the group is not valid, API_Groups_Id returns null.
func API_Groups_Id(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Groups_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64))
@ -261,7 +252,7 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) {
} }
// API_Groups_Id_Summary returns a summary of the groups owned by the current user. // API_Groups_Id_Summary returns a summary of the groups owned by the current user.
func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
vars := mux.Vars(r) vars := mux.Vars(r)
@ -276,7 +267,7 @@ func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
} }
// API_Templates handles the functionality for the /api/templates endpoint // API_Templates handles the functionality for the /api/templates endpoint
func API_Templates(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Templates(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
ts, err := models.GetTemplates(ctx.Get(r, "user_id").(int64)) ts, err := models.GetTemplates(ctx.Get(r, "user_id").(int64))
@ -319,7 +310,7 @@ func API_Templates(w http.ResponseWriter, r *http.Request) {
} }
// API_Templates_Id handles the functions for the /api/templates/:id endpoint // API_Templates_Id handles the functions for the /api/templates/:id endpoint
func API_Templates_Id(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Templates_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
t, err := models.GetTemplate(id, ctx.Get(r, "user_id").(int64)) t, err := models.GetTemplate(id, ctx.Get(r, "user_id").(int64))
@ -359,7 +350,7 @@ func API_Templates_Id(w http.ResponseWriter, r *http.Request) {
} }
// API_Pages handles requests for the /api/pages/ endpoint // API_Pages handles requests for the /api/pages/ endpoint
func API_Pages(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Pages(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
ps, err := models.GetPages(ctx.Get(r, "user_id").(int64)) ps, err := models.GetPages(ctx.Get(r, "user_id").(int64))
@ -396,7 +387,7 @@ func API_Pages(w http.ResponseWriter, r *http.Request) {
// API_Pages_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing // API_Pages_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
// of a Page object // of a Page object
func API_Pages_Id(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Pages_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
p, err := models.GetPage(id, ctx.Get(r, "user_id").(int64)) p, err := models.GetPage(id, ctx.Get(r, "user_id").(int64))
@ -436,7 +427,7 @@ func API_Pages_Id(w http.ResponseWriter, r *http.Request) {
} }
// API_SMTP handles requests for the /api/smtp/ endpoint // API_SMTP handles requests for the /api/smtp/ endpoint
func API_SMTP(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_SMTP(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
ss, err := models.GetSMTPs(ctx.Get(r, "user_id").(int64)) ss, err := models.GetSMTPs(ctx.Get(r, "user_id").(int64))
@ -473,7 +464,7 @@ func API_SMTP(w http.ResponseWriter, r *http.Request) {
// API_SMTP_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing // API_SMTP_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
// of a SMTP object // of a SMTP object
func API_SMTP_Id(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)
id, _ := strconv.ParseInt(vars["id"], 0, 64) id, _ := strconv.ParseInt(vars["id"], 0, 64)
s, err := models.GetSMTP(id, ctx.Get(r, "user_id").(int64)) s, err := models.GetSMTP(id, ctx.Get(r, "user_id").(int64))
@ -518,7 +509,7 @@ func API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
} }
// API_Import_Group imports a CSV of group members // API_Import_Group imports a CSV of group members
func API_Import_Group(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Import_Group(w http.ResponseWriter, r *http.Request) {
ts, err := util.ParseCSV(r) ts, err := util.ParseCSV(r)
if err != nil { if err != nil {
JSONResponse(w, models.Response{Success: false, Message: "Error parsing CSV"}, http.StatusInternalServerError) JSONResponse(w, models.Response{Success: false, Message: "Error parsing CSV"}, http.StatusInternalServerError)
@ -530,7 +521,7 @@ func API_Import_Group(w http.ResponseWriter, r *http.Request) {
// API_Import_Email allows for the importing of email. // API_Import_Email allows for the importing of email.
// Returns a Message object // Returns a Message object
func API_Import_Email(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Import_Email(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest) JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
return return
@ -579,7 +570,7 @@ func API_Import_Email(w http.ResponseWriter, r *http.Request) {
// API_Import_Site allows for the importing of HTML from a website // API_Import_Site allows for the importing of HTML from a website
// Without "include_resources" set, it will merely place a "base" tag // Without "include_resources" set, it will merely place a "base" tag
// so that all resources can be loaded relative to the given URL. // so that all resources can be loaded relative to the given URL.
func API_Import_Site(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Import_Site(w http.ResponseWriter, r *http.Request) {
cr := cloneRequest{} cr := cloneRequest{}
if r.Method != "POST" { if r.Method != "POST" {
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest) JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
@ -637,7 +628,7 @@ func API_Import_Site(w http.ResponseWriter, r *http.Request) {
// API_Send_Test_Email sends a test email using the template name // API_Send_Test_Email sends a test email using the template name
// and Target given. // and Target given.
func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
s := &models.EmailRequest{ s := &models.EmailRequest{
ErrorChan: make(chan error), ErrorChan: make(chan error),
UserId: ctx.Get(r, "user_id").(int64), UserId: ctx.Get(r, "user_id").(int64),
@ -735,7 +726,7 @@ func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
} }
} }
// Send the test email // Send the test email
err = Worker.SendTestEmail(s) err = as.worker.SendTestEmail(s)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError) JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)

View File

@ -11,7 +11,6 @@ import (
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gorilla/handlers"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@ -19,33 +18,35 @@ import (
type ControllersSuite struct { type ControllersSuite struct {
suite.Suite suite.Suite
ApiKey string ApiKey string
config *config.Config
adminServer *httptest.Server
phishServer *httptest.Server
} }
// as is the Admin Server for our API calls
var as *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreateAdminRouter()))
// ps is the Phishing Server
var ps *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreatePhishingRouter()))
func (s *ControllersSuite) SetupSuite() { func (s *ControllersSuite) SetupSuite() {
config.Conf.DBName = "sqlite3" conf := &config.Config{
config.Conf.DBPath = ":memory:" DBName: "sqlite3",
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" DBPath: ":memory:",
err := models.Setup() MigrationsPath: "../db/db_sqlite3/migrations/",
}
err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) s.T().Fatalf("Failed creating database: %v", err)
} }
s.config = conf
s.Nil(err) s.Nil(err)
// Setup the admin server for use in testing // Setup the admin server for use in testing
as.Config.Addr = config.Conf.AdminConf.ListenURL s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler)
as.Start() s.adminServer.Config.Addr = s.config.AdminConf.ListenURL
s.adminServer.Start()
// Get the API key to use for these tests // Get the API key to use for these tests
u, err := models.GetUser(1) u, err := models.GetUser(1)
s.Nil(err) s.Nil(err)
s.ApiKey = u.ApiKey s.ApiKey = u.ApiKey
// Start the phishing server // Start the phishing server
ps.Config.Addr = config.Conf.PhishConf.ListenURL s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler)
ps.Start() s.phishServer.Config.Addr = s.config.PhishConf.ListenURL
s.phishServer.Start()
// Move our cwd up to the project root for help with resolving // Move our cwd up to the project root for help with resolving
// static assets // static assets
err = os.Chdir("../") err = os.Chdir("../")
@ -103,21 +104,21 @@ func (s *ControllersSuite) SetupTest() {
} }
func (s *ControllersSuite) TestRequireAPIKey() { func (s *ControllersSuite) TestRequireAPIKey() {
resp, err := http.Post(fmt.Sprintf("%s/api/import/site", as.URL), "application/json", nil) resp, err := http.Post(fmt.Sprintf("%s/api/import/site", s.adminServer.URL), "application/json", nil)
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Equal(resp.StatusCode, http.StatusBadRequest) s.Equal(resp.StatusCode, http.StatusUnauthorized)
} }
func (s *ControllersSuite) TestInvalidAPIKey() { func (s *ControllersSuite) TestInvalidAPIKey() {
resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", as.URL, "bogus-api-key")) resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", s.adminServer.URL, "bogus-api-key"))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Equal(resp.StatusCode, http.StatusBadRequest) s.Equal(resp.StatusCode, http.StatusUnauthorized)
} }
func (s *ControllersSuite) TestBearerToken() { func (s *ControllersSuite) TestBearerToken() {
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", as.URL), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", s.adminServer.URL), nil)
s.Nil(err) s.Nil(err)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey))
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
@ -133,7 +134,7 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
})) }))
hr := fmt.Sprintf("<html><head><base href=\"%s\"/></head><body><img src=\"/test.png\"/>\n</body></html>", ts.URL) hr := fmt.Sprintf("<html><head><base href=\"%s\"/></head><body><img src=\"/test.png\"/>\n</body></html>", ts.URL)
defer ts.Close() defer ts.Close()
resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", as.URL, s.ApiKey), "application/json", resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", s.adminServer.URL, s.ApiKey), "application/json",
bytes.NewBuffer([]byte(fmt.Sprintf(` bytes.NewBuffer([]byte(fmt.Sprintf(`
{ {
"url" : "%s", "url" : "%s",
@ -150,8 +151,8 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
func (s *ControllersSuite) TearDownSuite() { func (s *ControllersSuite) TearDownSuite() {
// Tear down the admin and phishing servers // Tear down the admin and phishing servers
as.Close() s.adminServer.Close()
ps.Close() s.phishServer.Close()
} }
func TestControllerSuite(t *testing.T) { func TestControllerSuite(t *testing.T) {

View File

@ -1,6 +1,8 @@
package controllers package controllers
import ( import (
"compress/gzip"
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@ -8,11 +10,15 @@ import (
"strings" "strings"
"time" "time"
"github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context" ctx "github.com/gophish/gophish/context"
log "github.com/gophish/gophish/logger" log "github.com/gophish/gophish/logger"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gophish/gophish/util"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/jordan-wright/unindexed"
) )
// ErrInvalidRequest is thrown when a request with an invalid structure is // ErrInvalidRequest is thrown when a request with an invalid structure is
@ -35,22 +41,91 @@ type TransparencyResponse struct {
// to return a transparency response. // to return a transparency response.
const TransparencySuffix = "+" const TransparencySuffix = "+"
// CreatePhishingRouter creates the router that handles phishing connections. // PhishingServerOption is a functional option that is used to configure the
func CreatePhishingRouter() http.Handler { // the phishing server
router := mux.NewRouter() type PhishingServerOption func(*PhishingServer)
fileServer := http.FileServer(UnindexedFileSystem{http.Dir("./static/endpoint/")})
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer)) // PhishingServer is an HTTP server that implements the campaign event
router.HandleFunc("/track", PhishTracker) // handlers, such as email open tracking, click tracking, and more.
router.HandleFunc("/robots.txt", RobotsHandler) type PhishingServer struct {
router.HandleFunc("/{path:.*}/track", PhishTracker) server *http.Server
router.HandleFunc("/{path:.*}/report", PhishReporter) config config.PhishServer
router.HandleFunc("/report", PhishReporter) contactAddress string
router.HandleFunc("/{path:.*}", PhishHandler)
return router
} }
// PhishTracker tracks emails as they are opened, updating the status for the given Result // NewPhishingServer returns a new instance of the phishing server with
func PhishTracker(w http.ResponseWriter, r *http.Request) { // provided options applied.
func NewPhishingServer(config config.PhishServer, options ...PhishingServerOption) *PhishingServer {
defaultServer := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
Addr: config.ListenURL,
}
ps := &PhishingServer{
server: defaultServer,
config: config,
}
for _, opt := range options {
opt(ps)
}
ps.registerRoutes()
return ps
}
// WithContactAddress sets the contact address used by the transparency
// handlers
func WithContactAddress(addr string) PhishingServerOption {
return func(ps *PhishingServer) {
ps.contactAddress = addr
}
}
// Start launches the phishing server, listening on the configured address.
func (ps *PhishingServer) Start() error {
if ps.config.UseTLS {
err := util.CheckAndCreateSSL(ps.config.CertPath, ps.config.KeyPath)
if err != nil {
log.Fatal(err)
return err
}
log.Infof("Starting phishing server at https://%s", ps.config.ListenURL)
return ps.server.ListenAndServeTLS(ps.config.CertPath, ps.config.KeyPath)
}
// If TLS isn't configured, just listen on HTTP
log.Infof("Starting phishing server at http://%s", ps.config.ListenURL)
return ps.server.ListenAndServe()
}
// Shutdown attempts to gracefully shutdown the server.
func (ps *PhishingServer) Shutdown() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return ps.server.Shutdown(ctx)
}
// CreatePhishingRouter creates the router that handles phishing connections.
func (ps *PhishingServer) registerRoutes() {
router := mux.NewRouter()
fileServer := http.FileServer(unindexed.Dir("./static/endpoint/"))
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
router.HandleFunc("/track", ps.TrackHandler)
router.HandleFunc("/robots.txt", ps.RobotsHandler)
router.HandleFunc("/{path:.*}/track", ps.TrackHandler)
router.HandleFunc("/{path:.*}/report", ps.ReportHandler)
router.HandleFunc("/report", ps.ReportHandler)
router.HandleFunc("/{path:.*}", ps.PhishHandler)
// Setup GZIP compression
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
phishHandler := gzipWrapper(router)
// Setup logging
phishHandler = handlers.CombinedLoggingHandler(log.Writer(), phishHandler)
ps.server.Handler = router
}
// TrackHandler tracks emails as they are opened, updating the status for the given Result
func (ps *PhishingServer) TrackHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r) err, r := setupContext(r)
if err != nil { if err != nil {
// Log the error if it wasn't something we can safely ignore // Log the error if it wasn't something we can safely ignore
@ -71,7 +146,7 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request // Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) { if strings.HasSuffix(rid, TransparencySuffix) {
TransparencyHandler(w, r) ps.TransparencyHandler(w, r)
return return
} }
@ -82,8 +157,8 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "static/images/pixel.png") http.ServeFile(w, r, "static/images/pixel.png")
} }
// PhishReporter tracks emails as they are reported, updating the status for the given Result // ReportHandler tracks emails as they are reported, updating the status for the given Result
func PhishReporter(w http.ResponseWriter, r *http.Request) { func (ps *PhishingServer) ReportHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r) err, r := setupContext(r)
if err != nil { if err != nil {
// Log the error if it wasn't something we can safely ignore // Log the error if it wasn't something we can safely ignore
@ -104,7 +179,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request // Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) { if strings.HasSuffix(rid, TransparencySuffix) {
TransparencyHandler(w, r) ps.TransparencyHandler(w, r)
return return
} }
@ -117,7 +192,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
// PhishHandler handles incoming client connections and registers the associated actions performed // PhishHandler handles incoming client connections and registers the associated actions performed
// (such as clicked link, etc.) // (such as clicked link, etc.)
func PhishHandler(w http.ResponseWriter, r *http.Request) { func (ps *PhishingServer) PhishHandler(w http.ResponseWriter, r *http.Request) {
err, r := setupContext(r) err, r := setupContext(r)
if err != nil { if err != nil {
// Log the error if it wasn't something we can safely ignore // Log the error if it wasn't something we can safely ignore
@ -152,7 +227,7 @@ func PhishHandler(w http.ResponseWriter, r *http.Request) {
// Check for a transparency request // Check for a transparency request
if strings.HasSuffix(rid, TransparencySuffix) { if strings.HasSuffix(rid, TransparencySuffix) {
TransparencyHandler(w, r) ps.TransparencyHandler(w, r)
return return
} }
@ -211,23 +286,24 @@ func renderPhishResponse(w http.ResponseWriter, r *http.Request, ptx models.Phis
} }
// RobotsHandler prevents search engines, etc. from indexing phishing materials // RobotsHandler prevents search engines, etc. from indexing phishing materials
func RobotsHandler(w http.ResponseWriter, r *http.Request) { func (ps *PhishingServer) RobotsHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "User-agent: *\nDisallow: /") fmt.Fprintln(w, "User-agent: *\nDisallow: /")
} }
// TransparencyHandler returns a TransparencyResponse for the provided result // TransparencyHandler returns a TransparencyResponse for the provided result
// and campaign. // and campaign.
func TransparencyHandler(w http.ResponseWriter, r *http.Request) { func (ps *PhishingServer) TransparencyHandler(w http.ResponseWriter, r *http.Request) {
rs := ctx.Get(r, "result").(models.Result) rs := ctx.Get(r, "result").(models.Result)
tr := &TransparencyResponse{ tr := &TransparencyResponse{
Server: config.ServerName, Server: config.ServerName,
SendDate: rs.SendDate, SendDate: rs.SendDate,
ContactAddress: config.Conf.ContactAddress, ContactAddress: ps.contactAddress,
} }
JSONResponse(w, tr, http.StatusOK) JSONResponse(w, tr, http.StatusOK)
} }
// setupContext handles some of the administrative work around receiving a new request, such as checking the result ID, the campaign, etc. // setupContext handles some of the administrative work around receiving a new
// request, such as checking the result ID, the campaign, etc.
func setupContext(r *http.Request) (error, *http.Request) { func setupContext(r *http.Request) (error, *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {

View File

@ -38,7 +38,7 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest {
} }
func (s *ControllersSuite) openEmail(rid string) { func (s *ControllersSuite) openEmail(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
@ -49,19 +49,19 @@ func (s *ControllersSuite) openEmail(rid string) {
} }
func (s *ControllersSuite) reportedEmail(rid string) { func (s *ControllersSuite) reportedEmail(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNoContent) s.Equal(resp.StatusCode, http.StatusNoContent)
} }
func (s *ControllersSuite) reportEmail404(rid string) { func (s *ControllersSuite) reportEmail404(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound) s.Equal(resp.StatusCode, http.StatusNotFound)
} }
func (s *ControllersSuite) openEmail404(rid string) { func (s *ControllersSuite) openEmail404(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Nil(err) s.Nil(err)
@ -69,7 +69,7 @@ func (s *ControllersSuite) openEmail404(rid string) {
} }
func (s *ControllersSuite) clickLink(rid string, expectedHTML string) { func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
@ -79,7 +79,7 @@ func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
} }
func (s *ControllersSuite) clickLink404(rid string) { func (s *ControllersSuite) clickLink404(rid string) {
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Nil(err) s.Nil(err)
@ -87,14 +87,14 @@ func (s *ControllersSuite) clickLink404(rid string) {
} }
func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) { func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) {
resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ps.URL, path, models.RecipientParameter, rid)) resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid))
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
tr := &TransparencyResponse{} tr := &TransparencyResponse{}
err = json.NewDecoder(resp.Body).Decode(tr) err = json.NewDecoder(resp.Body).Decode(tr)
s.Nil(err) s.Nil(err)
s.Equal(tr.ContactAddress, config.Conf.ContactAddress) s.Equal(tr.ContactAddress, s.config.ContactAddress)
s.Equal(tr.SendDate, r.SendDate) s.Equal(tr.SendDate, r.SendDate)
s.Equal(tr.Server, config.ServerName) s.Equal(tr.Server, config.ServerName)
} }
@ -146,11 +146,11 @@ func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() {
} }
func (s *ControllersSuite) TestNoRecipientID() { func (s *ControllersSuite) TestNoRecipientID() {
resp, err := http.Get(fmt.Sprintf("%s/track", ps.URL)) resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL))
s.Nil(err) s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound) s.Equal(resp.StatusCode, http.StatusNotFound)
resp, err = http.Get(ps.URL) resp, err = http.Get(s.phishServer.URL)
s.Nil(err) s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound) s.Equal(resp.StatusCode, http.StatusNotFound)
} }
@ -183,7 +183,7 @@ func (s *ControllersSuite) TestCompletedCampaignClick() {
func (s *ControllersSuite) TestRobotsHandler() { func (s *ControllersSuite) TestRobotsHandler() {
expected := []byte("User-agent: *\nDisallow: /\n") expected := []byte("User-agent: *\nDisallow: /\n")
resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ps.URL)) resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL))
s.Nil(err) s.Nil(err)
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
defer resp.Body.Close() defer resp.Body.Close()
@ -259,7 +259,7 @@ func (s *ControllersSuite) TestRedirectTemplating() {
}, },
} }
result := campaign.Results[0] result := campaign.Results[0]
resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}}) resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
s.Nil(err) s.Nil(err)
defer resp.Body.Close() defer resp.Body.Close()
s.Equal(http.StatusFound, resp.StatusCode) s.Equal(http.StatusFound, resp.StatusCode)

View File

@ -1,72 +1,153 @@
package controllers package controllers
import ( import (
"fmt" "compress/gzip"
"context"
"html/template" "html/template"
"net/http" "net/http"
"net/url" "net/url"
"time"
"github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/auth" "github.com/gophish/gophish/auth"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
ctx "github.com/gophish/gophish/context" ctx "github.com/gophish/gophish/context"
log "github.com/gophish/gophish/logger" log "github.com/gophish/gophish/logger"
mid "github.com/gophish/gophish/middleware" mid "github.com/gophish/gophish/middleware"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gophish/gophish/util"
"github.com/gophish/gophish/worker"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/jordan-wright/unindexed"
) )
// CreateAdminRouter creates the routes for handling requests to the web interface. // AdminServerOption is a functional option that is used to configure the
// admin server
type AdminServerOption func(*AdminServer)
// AdminServer is an HTTP server that implements the administrative Gophish
// handlers, including the dashboard and REST API.
type AdminServer struct {
server *http.Server
worker worker.Worker
config config.AdminServer
}
// WithWorker is an option that sets the background worker.
func WithWorker(w worker.Worker) AdminServerOption {
return func(as *AdminServer) {
as.worker = w
}
}
// NewAdminServer returns a new instance of the AdminServer with the
// provided config and options applied.
func NewAdminServer(config config.AdminServer, options ...AdminServerOption) *AdminServer {
defaultWorker, _ := worker.New()
defaultServer := &http.Server{
ReadTimeout: 10 * time.Second,
Addr: config.ListenURL,
}
as := &AdminServer{
worker: defaultWorker,
server: defaultServer,
config: config,
}
for _, opt := range options {
opt(as)
}
as.registerRoutes()
return as
}
// Start launches the admin server, listening on the configured address.
func (as *AdminServer) Start() error {
if as.worker != nil {
go as.worker.Start()
}
if as.config.UseTLS {
err := util.CheckAndCreateSSL(as.config.CertPath, as.config.KeyPath)
if err != nil {
log.Fatal(err)
return err
}
log.Infof("Starting admin server at https://%s", as.config.ListenURL)
return as.server.ListenAndServeTLS(as.config.CertPath, as.config.KeyPath)
}
// If TLS isn't configured, just listen on HTTP
log.Infof("Starting admin server at http://%s", as.config.ListenURL)
return as.server.ListenAndServe()
}
// Shutdown attempts to gracefully shutdown the server.
func (as *AdminServer) Shutdown() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return as.server.Shutdown(ctx)
}
// SetupAdminRoutes creates the routes for handling requests to the web interface.
// This function returns an http.Handler to be used in http.ListenAndServe(). // This function returns an http.Handler to be used in http.ListenAndServe().
func CreateAdminRouter() http.Handler { func (as *AdminServer) registerRoutes() {
router := mux.NewRouter() router := mux.NewRouter()
// Base Front-end routes // Base Front-end routes
router.HandleFunc("/", Use(Base, mid.RequireLogin)) router.HandleFunc("/", Use(as.Base, mid.RequireLogin))
router.HandleFunc("/login", Login) router.HandleFunc("/login", as.Login)
router.HandleFunc("/logout", Use(Logout, mid.RequireLogin)) router.HandleFunc("/logout", Use(as.Logout, mid.RequireLogin))
router.HandleFunc("/campaigns", Use(Campaigns, mid.RequireLogin)) router.HandleFunc("/campaigns", Use(as.Campaigns, mid.RequireLogin))
router.HandleFunc("/campaigns/{id:[0-9]+}", Use(CampaignID, mid.RequireLogin)) router.HandleFunc("/campaigns/{id:[0-9]+}", Use(as.CampaignID, mid.RequireLogin))
router.HandleFunc("/templates", Use(Templates, mid.RequireLogin)) router.HandleFunc("/templates", Use(as.Templates, mid.RequireLogin))
router.HandleFunc("/users", Use(Users, mid.RequireLogin)) router.HandleFunc("/users", Use(as.Users, mid.RequireLogin))
router.HandleFunc("/landing_pages", Use(LandingPages, mid.RequireLogin)) router.HandleFunc("/landing_pages", Use(as.LandingPages, mid.RequireLogin))
router.HandleFunc("/sending_profiles", Use(SendingProfiles, mid.RequireLogin)) router.HandleFunc("/sending_profiles", Use(as.SendingProfiles, mid.RequireLogin))
router.HandleFunc("/register", Use(Register, mid.RequireLogin)) router.HandleFunc("/register", Use(as.Register, mid.RequireLogin))
router.HandleFunc("/settings", Use(Settings, mid.RequireLogin)) router.HandleFunc("/settings", Use(as.Settings, mid.RequireLogin))
// Create the API routes // Create the API routes
api := router.PathPrefix("/api").Subrouter() api := router.PathPrefix("/api").Subrouter()
api = api.StrictSlash(true) api = api.StrictSlash(true)
api.HandleFunc("/reset", Use(API_Reset, mid.RequireAPIKey)) api.Use(mid.RequireAPIKey)
api.HandleFunc("/campaigns/", Use(API_Campaigns, mid.RequireAPIKey)) api.HandleFunc("/reset", as.API_Reset)
api.HandleFunc("/campaigns/summary", Use(API_Campaigns_Summary, mid.RequireAPIKey)) api.HandleFunc("/campaigns/", as.API_Campaigns)
api.HandleFunc("/campaigns/{id:[0-9]+}", Use(API_Campaigns_Id, mid.RequireAPIKey)) api.HandleFunc("/campaigns/summary", as.API_Campaigns_Summary)
api.HandleFunc("/campaigns/{id:[0-9]+}/results", Use(API_Campaigns_Id_Results, mid.RequireAPIKey)) api.HandleFunc("/campaigns/{id:[0-9]+}", as.API_Campaigns_Id)
api.HandleFunc("/campaigns/{id:[0-9]+}/summary", Use(API_Campaign_Id_Summary, mid.RequireAPIKey)) api.HandleFunc("/campaigns/{id:[0-9]+}/results", as.API_Campaigns_Id_Results)
api.HandleFunc("/campaigns/{id:[0-9]+}/complete", Use(API_Campaigns_Id_Complete, mid.RequireAPIKey)) api.HandleFunc("/campaigns/{id:[0-9]+}/summary", as.API_Campaign_Id_Summary)
api.HandleFunc("/groups/", Use(API_Groups, mid.RequireAPIKey)) api.HandleFunc("/campaigns/{id:[0-9]+}/complete", as.API_Campaigns_Id_Complete)
api.HandleFunc("/groups/summary", Use(API_Groups_Summary, mid.RequireAPIKey)) api.HandleFunc("/groups/", as.API_Groups)
api.HandleFunc("/groups/{id:[0-9]+}", Use(API_Groups_Id, mid.RequireAPIKey)) api.HandleFunc("/groups/summary", as.API_Groups_Summary)
api.HandleFunc("/groups/{id:[0-9]+}/summary", Use(API_Groups_Id_Summary, mid.RequireAPIKey)) api.HandleFunc("/groups/{id:[0-9]+}", as.API_Groups_Id)
api.HandleFunc("/templates/", Use(API_Templates, mid.RequireAPIKey)) api.HandleFunc("/groups/{id:[0-9]+}/summary", as.API_Groups_Id_Summary)
api.HandleFunc("/templates/{id:[0-9]+}", Use(API_Templates_Id, mid.RequireAPIKey)) api.HandleFunc("/templates/", as.API_Templates)
api.HandleFunc("/pages/", Use(API_Pages, mid.RequireAPIKey)) api.HandleFunc("/templates/{id:[0-9]+}", as.API_Templates_Id)
api.HandleFunc("/pages/{id:[0-9]+}", Use(API_Pages_Id, mid.RequireAPIKey)) api.HandleFunc("/pages/", as.API_Pages)
api.HandleFunc("/smtp/", Use(API_SMTP, mid.RequireAPIKey)) api.HandleFunc("/pages/{id:[0-9]+}", as.API_Pages_Id)
api.HandleFunc("/smtp/{id:[0-9]+}", Use(API_SMTP_Id, mid.RequireAPIKey)) api.HandleFunc("/smtp/", as.API_SMTP)
api.HandleFunc("/util/send_test_email", Use(API_Send_Test_Email, mid.RequireAPIKey)) api.HandleFunc("/smtp/{id:[0-9]+}", as.API_SMTP_Id)
api.HandleFunc("/import/group", Use(API_Import_Group, mid.RequireAPIKey)) api.HandleFunc("/util/send_test_email", as.API_Send_Test_Email)
api.HandleFunc("/import/email", Use(API_Import_Email, mid.RequireAPIKey)) api.HandleFunc("/import/group", as.API_Import_Group)
api.HandleFunc("/import/site", Use(API_Import_Site, mid.RequireAPIKey)) api.HandleFunc("/import/email", as.API_Import_Email)
api.HandleFunc("/import/site", as.API_Import_Site)
// Setup static file serving // Setup static file serving
router.PathPrefix("/").Handler(http.FileServer(UnindexedFileSystem{http.Dir("./static/")})) router.PathPrefix("/").Handler(http.FileServer(unindexed.Dir("./static/")))
// Setup CSRF Protection // Setup CSRF Protection
csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()), csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()),
csrf.FieldName("csrf_token"), csrf.FieldName("csrf_token"),
csrf.Secure(config.Conf.AdminConf.UseTLS)) csrf.Secure(as.config.UseTLS))
csrfRouter := csrfHandler(router) adminHandler := csrfHandler(router)
return Use(csrfRouter.ServeHTTP, mid.CSRFExceptions, mid.GetContext) adminHandler = Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext)
// Setup GZIP compression
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
adminHandler = gzipWrapper(adminHandler)
// Setup logging
adminHandler = handlers.CombinedLoggingHandler(log.Writer(), adminHandler)
as.server.Handler = adminHandler
} }
// Use allows us to stack middleware to process the request // Use allows us to stack middleware to process the request
@ -78,16 +159,29 @@ func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) h
return handler return handler
} }
// Register creates a new user type templateParams struct {
func Register(w http.ResponseWriter, r *http.Request) {
// If it is a post request, attempt to register the account
// Now that we are all registered, we can log the user in
params := struct {
Title string Title string
Flashes []interface{} Flashes []interface{}
User models.User User models.User
Token string Token string
}{Title: "Register", Token: csrf.Token(r)} Version string
}
// newTemplateParams returns the default template parameters for a user and
// the CSRF token.
func newTemplateParams(r *http.Request) templateParams {
return templateParams{
Token: csrf.Token(r),
User: ctx.Get(r, "user").(models.User),
Version: config.Version,
}
}
// Register creates a new user
func (as *AdminServer) Register(w http.ResponseWriter, r *http.Request) {
// If it is a post request, attempt to register the account
// Now that we are all registered, we can log the user in
params := templateParams{Title: "Register", Token: csrf.Token(r)}
session := ctx.Get(r, "session").(*sessions.Session) session := ctx.Get(r, "session").(*sessions.Session)
switch { switch {
case r.Method == "GET": case r.Method == "GET":
@ -120,99 +214,60 @@ func Register(w http.ResponseWriter, r *http.Request) {
} }
// Base handles the default path and template execution // Base handles the default path and template execution
func Base(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Base(w http.ResponseWriter, r *http.Request) {
params := struct { params := newTemplateParams(r)
User models.User params.Title = "Dashboard"
Title string
Flashes []interface{}
Token string
}{Title: "Dashboard", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "dashboard").ExecuteTemplate(w, "base", params) getTemplate(w, "dashboard").ExecuteTemplate(w, "base", params)
} }
// Campaigns handles the default path and template execution // Campaigns handles the default path and template execution
func Campaigns(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Campaigns(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Campaigns"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Campaigns", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "campaigns").ExecuteTemplate(w, "base", params) getTemplate(w, "campaigns").ExecuteTemplate(w, "base", params)
} }
// CampaignID handles the default path and template execution // CampaignID handles the default path and template execution
func CampaignID(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) CampaignID(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Campaign Results"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Campaign Results", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "campaign_results").ExecuteTemplate(w, "base", params) getTemplate(w, "campaign_results").ExecuteTemplate(w, "base", params)
} }
// Templates handles the default path and template execution // Templates handles the default path and template execution
func Templates(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Templates(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Email Templates"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Email Templates", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "templates").ExecuteTemplate(w, "base", params) getTemplate(w, "templates").ExecuteTemplate(w, "base", params)
} }
// Users handles the default path and template execution // Users handles the default path and template execution
func Users(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Users(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Users & Groups"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Users & Groups", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "users").ExecuteTemplate(w, "base", params) getTemplate(w, "users").ExecuteTemplate(w, "base", params)
} }
// LandingPages handles the default path and template execution // LandingPages handles the default path and template execution
func LandingPages(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) LandingPages(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Landing Pages"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Landing Pages", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "landing_pages").ExecuteTemplate(w, "base", params) getTemplate(w, "landing_pages").ExecuteTemplate(w, "base", params)
} }
// SendingProfiles handles the default path and template execution // SendingProfiles handles the default path and template execution
func SendingProfiles(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) SendingProfiles(w http.ResponseWriter, r *http.Request) {
// Example of using session - will be removed. params := newTemplateParams(r)
params := struct { params.Title = "Sending Profiles"
User models.User
Title string
Flashes []interface{}
Token string
}{Title: "Sending Profiles", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "sending_profiles").ExecuteTemplate(w, "base", params) getTemplate(w, "sending_profiles").ExecuteTemplate(w, "base", params)
} }
// Settings handles the changing of settings // Settings handles the changing of settings
func Settings(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Settings(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.Method == "GET": case r.Method == "GET":
params := struct { params := newTemplateParams(r)
User models.User params.Title = "Settings"
Title string
Flashes []interface{}
Token string
Version string
}{Title: "Settings", Version: config.Version, User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
getTemplate(w, "settings").ExecuteTemplate(w, "base", params) getTemplate(w, "settings").ExecuteTemplate(w, "base", params)
case r.Method == "POST": case r.Method == "POST":
err := auth.ChangePassword(r) err := auth.ChangePassword(r)
@ -235,7 +290,7 @@ func Settings(w http.ResponseWriter, r *http.Request) {
// Login handles the authentication flow for a user. If credentials are valid, // Login handles the authentication flow for a user. If credentials are valid,
// a session is created // a session is created
func Login(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Login(w http.ResponseWriter, r *http.Request) {
params := struct { params := struct {
User models.User User models.User
Title string Title string
@ -289,7 +344,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
} }
// Logout destroys the current user session // Logout destroys the current user session
func Logout(w http.ResponseWriter, r *http.Request) { func (as *AdminServer) Logout(w http.ResponseWriter, r *http.Request) {
session := ctx.Get(r, "session").(*sessions.Session) session := ctx.Get(r, "session").(*sessions.Session)
delete(session.Values, "id") delete(session.Values, "id")
Flash(w, r, "success", "You have successfully logged out") Flash(w, r, "success", "You have successfully logged out")
@ -297,28 +352,6 @@ func Logout(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", 302) http.Redirect(w, r, "/login", 302)
} }
// Preview allows for the viewing of page html in a separate browser window
func Preview(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusBadRequest)
return
}
fmt.Fprintf(w, "%s", r.FormValue("html"))
}
// Clone takes a URL as a POST parameter and returns the site HTML
func Clone(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusBadRequest)
return
}
if url, ok := vars["url"]; ok {
log.Error(url)
}
http.Error(w, "No URL given.", http.StatusBadRequest)
}
func getTemplate(w http.ResponseWriter, tmpl string) *template.Template { func getTemplate(w http.ResponseWriter, tmpl string) *template.Template {
templates := template.New("template") templates := template.New("template")
_, err := templates.ParseFiles("templates/base.html", "templates/"+tmpl+".html", "templates/flashes.html") _, err := templates.ParseFiles("templates/base.html", "templates/"+tmpl+".html", "templates/flashes.html")

View File

@ -10,7 +10,7 @@ import (
) )
func (s *ControllersSuite) TestLoginCSRF() { func (s *ControllersSuite) TestLoginCSRF() {
resp, err := http.PostForm(fmt.Sprintf("%s/login", as.URL), resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL),
url.Values{ url.Values{
"username": {"admin"}, "username": {"admin"},
"password": {"gophish"}, "password": {"gophish"},
@ -21,7 +21,7 @@ func (s *ControllersSuite) TestLoginCSRF() {
} }
func (s *ControllersSuite) TestInvalidCredentials() { func (s *ControllersSuite) TestInvalidCredentials() {
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil) s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
@ -32,7 +32,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
s.Equal(ok, true) s.Equal(ok, true)
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"}, "username": {"admin"},
"password": {"invalid"}, "password": {"invalid"},
"csrf_token": {token}, "csrf_token": {token},
@ -48,7 +48,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
} }
func (s *ControllersSuite) TestSuccessfulLogin() { func (s *ControllersSuite) TestSuccessfulLogin() {
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil) s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
@ -59,7 +59,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
s.Equal(ok, true) s.Equal(ok, true)
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
"username": {"admin"}, "username": {"admin"},
"password": {"gophish"}, "password": {"gophish"},
"csrf_token": {token}, "csrf_token": {token},
@ -76,7 +76,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
func (s *ControllersSuite) TestSuccessfulRedirect() { func (s *ControllersSuite) TestSuccessfulRedirect() {
next := "/campaigns" next := "/campaigns"
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL)) resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
s.Equal(err, nil) s.Equal(err, nil)
s.Equal(resp.StatusCode, http.StatusOK) s.Equal(resp.StatusCode, http.StatusOK)
@ -91,7 +91,7 @@ func (s *ControllersSuite) TestSuccessfulRedirect() {
return http.ErrUseLastResponse return http.ErrUseLastResponse
}, },
} }
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", as.URL, next), strings.NewReader(url.Values{ req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{
"username": {"admin"}, "username": {"admin"},
"password": {"gophish"}, "password": {"gophish"},
"csrf_token": {token}, "csrf_token": {token},

View File

@ -1,35 +0,0 @@
package controllers
import (
"net/http"
"strings"
)
// UnindexedFileSystem is an implementation of a standard http.FileSystem
// without the ability to list files in the directory.
// This implementation is largely inspired by
// https://www.alexedwards.net/blog/disable-http-fileserver-directory-listings
type UnindexedFileSystem struct {
fs http.FileSystem
}
// Open returns a file from the static directory. If the requested path ends
// with a slash, there is a check for an index.html file. If none exists, then
// an error is returned.
func (ufs UnindexedFileSystem) Open(name string) (http.File, error) {
f, err := ufs.fs.Open(name)
if err != nil {
return nil, err
}
s, err := f.Stat()
if s.IsDir() {
index := strings.TrimSuffix(name, "/") + "/index.html"
indexFile, err := ufs.fs.Open(index)
if err != nil {
return nil, err
}
return indexFile, nil
}
return f, nil
}

View File

@ -1,81 +0,0 @@
package controllers
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"os"
"path/filepath"
)
var fileContent = []byte("Hello world")
func mustRemoveAll(dir string) {
err := os.RemoveAll(dir)
if err != nil {
panic(err)
}
}
func createTestFile(dir, filename string) error {
return ioutil.WriteFile(filepath.Join(dir, filename), fileContent, 0644)
}
func (s *ControllersSuite) TestGetStaticFile() {
dir, err := ioutil.TempDir("static/endpoint", "test-")
tempFolder := filepath.Base(dir)
s.Nil(err)
defer mustRemoveAll(dir)
err = createTestFile(dir, "foo.txt")
s.Nil(nil, err)
resp, err := http.Get(fmt.Sprintf("%s/static/%s/foo.txt", ps.URL, tempFolder))
s.Nil(err)
defer resp.Body.Close()
got, err := ioutil.ReadAll(resp.Body)
s.Nil(err)
s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
}
func (s *ControllersSuite) TestStaticFileListing() {
dir, err := ioutil.TempDir("static/endpoint", "test-")
tempFolder := filepath.Base(dir)
s.Nil(err)
defer mustRemoveAll(dir)
err = createTestFile(dir, "foo.txt")
s.Nil(nil, err)
resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
s.Nil(err)
defer resp.Body.Close()
s.Nil(err)
s.Equal(resp.StatusCode, http.StatusNotFound)
}
func (s *ControllersSuite) TestStaticIndex() {
dir, err := ioutil.TempDir("static/endpoint", "test-")
tempFolder := filepath.Base(dir)
s.Nil(err)
defer mustRemoveAll(dir)
err = createTestFile(dir, "index.html")
s.Nil(nil, err)
resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
s.Nil(err)
defer resp.Body.Close()
got, err := ioutil.ReadAll(resp.Body)
s.Nil(err)
s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
}

View File

@ -26,24 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. THE SOFTWARE.
*/ */
import ( import (
"compress/gzip"
"context"
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"sync" "os/signal"
"gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
"github.com/NYTimes/gziphandler"
"github.com/gophish/gophish/auth" "github.com/gophish/gophish/auth"
"github.com/gophish/gophish/config" "github.com/gophish/gophish/config"
"github.com/gophish/gophish/controllers" "github.com/gophish/gophish/controllers"
log "github.com/gophish/gophish/logger" log "github.com/gophish/gophish/logger"
"github.com/gophish/gophish/mailer"
"github.com/gophish/gophish/models" "github.com/gophish/gophish/models"
"github.com/gophish/gophish/util"
"github.com/gorilla/handlers"
) )
var ( var (
@ -65,31 +58,25 @@ func main() {
kingpin.Parse() kingpin.Parse()
// Load the config // Load the config
err = config.LoadConfig(*configPath) conf, err := config.LoadConfig(*configPath)
// Just warn if a contact address hasn't been configured // Just warn if a contact address hasn't been configured
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if config.Conf.ContactAddress == "" { if conf.ContactAddress == "" {
log.Warnf("No contact address has been configured.") log.Warnf("No contact address has been configured.")
log.Warnf("Please consider adding a contact_address entry in your config.json") log.Warnf("Please consider adding a contact_address entry in your config.json")
} }
config.Version = string(version) config.Version = string(version)
err = log.Setup() err = log.Setup(conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Provide the option to disable the built-in mailer // Provide the option to disable the built-in mailer
if !*disableMailer {
go mailer.Mailer.Start(ctx)
}
// Setup the global variables and settings // Setup the global variables and settings
err = models.Setup() err = models.Setup(conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -99,39 +86,27 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
wg := &sync.WaitGroup{}
wg.Add(1) // Create our servers
// Start the web servers adminOptions := []controllers.AdminServerOption{}
go func() { if *disableMailer {
defer wg.Done() adminOptions = append(adminOptions, controllers.WithWorker(nil))
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
adminHandler := gzipWrapper(controllers.CreateAdminRouter())
auth.Store.Options.Secure = config.Conf.AdminConf.UseTLS
if config.Conf.AdminConf.UseTLS { // use TLS for Admin web server if available
err := util.CheckAndCreateSSL(config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath)
if err != nil {
log.Fatal(err)
} }
log.Infof("Starting admin server at https://%s", config.Conf.AdminConf.ListenURL) adminConfig := conf.AdminConf
log.Info(http.ListenAndServeTLS(config.Conf.AdminConf.ListenURL, config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath, adminServer := controllers.NewAdminServer(adminConfig, adminOptions...)
handlers.CombinedLoggingHandler(log.Writer(), adminHandler))) auth.Store.Options.Secure = adminConfig.UseTLS
} else {
log.Infof("Starting admin server at http://%s", config.Conf.AdminConf.ListenURL) phishConfig := conf.PhishConf
log.Info(http.ListenAndServe(config.Conf.AdminConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, adminHandler))) phishServer := controllers.NewPhishingServer(phishConfig)
}
}() go adminServer.Start()
wg.Add(1) go phishServer.Start()
go func() {
defer wg.Done() // Handle graceful shutdown
phishHandler := gziphandler.GzipHandler(controllers.CreatePhishingRouter()) c := make(chan os.Signal, 1)
if config.Conf.PhishConf.UseTLS { // use TLS for Phish web server if available signal.Notify(c, os.Interrupt)
log.Infof("Starting phishing server at https://%s", config.Conf.PhishConf.ListenURL) <-c
log.Info(http.ListenAndServeTLS(config.Conf.PhishConf.ListenURL, config.Conf.PhishConf.CertPath, config.Conf.PhishConf.KeyPath, log.Info("CTRL+C Received... Gracefully shutting down servers")
handlers.CombinedLoggingHandler(log.Writer(), phishHandler))) adminServer.Shutdown()
} else { phishServer.Shutdown()
log.Infof("Starting phishing server at http://%s", config.Conf.PhishConf.ListenURL)
log.Fatal(http.ListenAndServe(config.Conf.PhishConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, phishHandler)))
}
}()
wg.Wait()
} }

View File

@ -18,10 +18,10 @@ func init() {
} }
// Setup configures the logger based on options in the config.json. // Setup configures the logger based on options in the config.json.
func Setup() error { func Setup(conf *config.Config) error {
Logger.SetLevel(logrus.InfoLevel) Logger.SetLevel(logrus.InfoLevel)
// Set up logging to a file if specified in the config // Set up logging to a file if specified in the config
logFile := config.Conf.Logging.Filename logFile := conf.Logging.Filename
if logFile != "" { if logFile != "" {
f, err := os.OpenFile(logFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644) f, err := os.OpenFile(logFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil { if err != nil {

View File

@ -29,6 +29,13 @@ func (e *ErrMaxConnectAttempts) Error() string {
return errString return errString
} }
// Mailer is an interface that defines an object used to queue and
// send mailer.Mail instances.
type Mailer interface {
Start(ctx context.Context)
Queue([]Mail)
}
// Sender exposes the common operations required for sending email. // Sender exposes the common operations required for sending email.
type Sender interface { type Sender interface {
Send(from string, to []string, msg io.WriterTo) error Send(from string, to []string, msg io.WriterTo) error
@ -50,27 +57,18 @@ type Mail interface {
GetDialer() (Dialer, error) GetDialer() (Dialer, error)
} }
// Mailer is a global instance of the mailer that can
// be used in applications. It is the responsibility of the application
// to call Mailer.Start()
var Mailer *MailWorker
func init() {
Mailer = NewMailWorker()
}
// MailWorker is the worker that receives slices of emails // MailWorker is the worker that receives slices of emails
// on a channel to send. It's assumed that every slice of emails received is meant // on a channel to send. It's assumed that every slice of emails received is meant
// to be sent to the same server. // to be sent to the same server.
type MailWorker struct { type MailWorker struct {
Queue chan []Mail queue chan []Mail
} }
// NewMailWorker returns an instance of MailWorker with the mail queue // NewMailWorker returns an instance of MailWorker with the mail queue
// initialized. // initialized.
func NewMailWorker() *MailWorker { func NewMailWorker() *MailWorker {
return &MailWorker{ return &MailWorker{
Queue: make(chan []Mail), queue: make(chan []Mail),
} }
} }
@ -81,7 +79,7 @@ func (mw *MailWorker) Start(ctx context.Context) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case ms := <-mw.Queue: case ms := <-mw.queue:
go func(ctx context.Context, ms []Mail) { go func(ctx context.Context, ms []Mail) {
dialer, err := ms[0].GetDialer() dialer, err := ms[0].GetDialer()
if err != nil { if err != nil {
@ -94,6 +92,11 @@ func (mw *MailWorker) Start(ctx context.Context) {
} }
} }
// Queue sends the provided mail to the internal queue for processing.
func (mw *MailWorker) Queue(ms []Mail) {
mw.queue <- ms
}
// errorMail is a helper to handle erroring out a slice of Mail instances // errorMail is a helper to handle erroring out a slice of Mail instances
// in the case that an unrecoverable error occurs. // in the case that an unrecoverable error occurs.
func errorMail(err error, ms []Mail) { func errorMail(err error, ms []Mail) {

View File

@ -88,7 +88,7 @@ func (ms *MailerSuite) TestMailWorkerStart() {
messages := generateMessages(dialer) messages := generateMessages(dialer)
// Send the campaign // Send the campaign
mw.Queue <- messages mw.Queue(messages)
got := []*mockMessage{} got := []*mockMessage{}
@ -129,7 +129,7 @@ func (ms *MailerSuite) TestBackoff() {
messages := generateMessages(dialer) messages := generateMessages(dialer)
// Send the campaign // Send the campaign
mw.Queue <- messages mw.Queue(messages)
got := []*mockMessage{} got := []*mockMessage{}
@ -183,7 +183,7 @@ func (ms *MailerSuite) TestPermError() {
messages := generateMessages(dialer) messages := generateMessages(dialer)
// Send the campaign // Send the campaign
mw.Queue <- messages mw.Queue(messages)
got := []*mockMessage{} got := []*mockMessage{}
@ -242,7 +242,7 @@ func (ms *MailerSuite) TestUnknownError() {
messages := generateMessages(dialer) messages := generateMessages(dialer)
// Send the campaign // Send the campaign
mw.Queue <- messages mw.Queue(messages)
got := []*mockMessage{} got := []*mockMessage{}

View File

@ -60,8 +60,10 @@ func GetContext(handler http.Handler) http.HandlerFunc {
} }
} }
func RequireAPIKey(handler http.Handler) http.HandlerFunc { // RequireAPIKey ensures that a valid API key is set as either the api_key GET
return func(w http.ResponseWriter, r *http.Request) { // parameter, or a Bearer token.
func RequireAPIKey(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == "OPTIONS" { if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
@ -81,18 +83,18 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc {
} }
} }
if ak == "" { if ak == "" {
JSONError(w, 400, "API Key not set") JSONError(w, http.StatusUnauthorized, "API Key not set")
return return
} }
u, err := models.GetUserByAPIKey(ak) u, err := models.GetUserByAPIKey(ak)
if err != nil { if err != nil {
JSONError(w, 400, "Invalid API Key") JSONError(w, http.StatusUnauthorized, "Invalid API Key")
return return
} }
r = ctx.Set(r, "user_id", u.Id) r = ctx.Set(r, "user_id", u.Id)
r = ctx.Set(r, "api_key", ak) r = ctx.Set(r, "api_key", ak)
handler.ServeHTTP(w, r) handler.ServeHTTP(w, r)
} })
} }
// RequireLogin is a simple middleware which checks to see if the user is currently logged in. // RequireLogin is a simple middleware which checks to see if the user is currently logged in.
@ -104,7 +106,7 @@ func RequireLogin(handler http.Handler) http.HandlerFunc {
} else { } else {
q := r.URL.Query() q := r.URL.Query()
q.Set("next", r.URL.Path) q.Set("next", r.URL.Path)
http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), 302) http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
} }
} }
} }

View File

@ -165,7 +165,7 @@ func (c *Campaign) AddEvent(e *Event) error {
// an error is returned. Otherwise, the attribute name is set to [Deleted], // an error is returned. Otherwise, the attribute name is set to [Deleted],
// indicating the user deleted the attribute (template, smtp, etc.) // indicating the user deleted the attribute (template, smtp, etc.)
func (c *Campaign) getDetails() error { func (c *Campaign) getDetails() error {
err = db.Model(c).Related(&c.Results).Error err := db.Model(c).Related(&c.Results).Error
if err != nil { if err != nil {
log.Warnf("%s: results not found for campaign", err) log.Warnf("%s: results not found for campaign", err)
return err return err
@ -402,7 +402,8 @@ func GetQueuedCampaigns(t time.Time) ([]Campaign, error) {
// PostCampaign inserts a campaign and all associated records into the database. // PostCampaign inserts a campaign and all associated records into the database.
func PostCampaign(c *Campaign, uid int64) error { func PostCampaign(c *Campaign, uid int64) error {
if err := c.Validate(); err != nil { err := c.Validate()
if err != nil {
return err return err
} }
// Fill in the details // Fill in the details
@ -514,14 +515,16 @@ func PostCampaign(c *Campaign, uid int64) error {
Reported: false, Reported: false,
ModifiedDate: c.CreatedDate, ModifiedDate: c.CreatedDate,
} }
if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
r.Status = STATUS_SENDING
}
err = r.GenerateId() err = r.GenerateId()
if err != nil { if err != nil {
log.Error(err) log.Error(err)
continue continue
} }
processing := false
if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
r.Status = STATUS_SENDING
processing = true
}
err = db.Save(r).Error err = db.Save(r).Error
if err != nil { if err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
@ -530,7 +533,14 @@ func PostCampaign(c *Campaign, uid int64) error {
} }
c.Results = append(c.Results, *r) c.Results = append(c.Results, *r)
log.Infof("Creating maillog for %s to send at %s\n", r.Email, sendDate) log.Infof("Creating maillog for %s to send at %s\n", r.Email, sendDate)
err = GenerateMailLog(c, r, sendDate) m := &MailLog{
UserId: c.UserId,
CampaignId: c.Id,
RId: r.RId,
SendDate: sendDate,
Processing: processing,
}
err = db.Save(m).Error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
continue continue

View File

@ -79,3 +79,29 @@ func (s *ModelsSuite) TestCampaignDateValidation(c *check.C) {
err = campaign.Validate() err = campaign.Validate()
c.Assert(err, check.Equals, ErrInvalidSendByDate) c.Assert(err, check.Equals, ErrInvalidSendByDate)
} }
func (s *ModelsSuite) TestLaunchCampaignMaillogStatus(c *check.C) {
// For the first test, ensure that campaigns created with the zero date
// (and therefore are set to launch immediately) have maillogs that are
// locked to prevent race conditions.
campaign := s.createCampaign(c)
ms, err := GetMailLogsByCampaign(campaign.Id)
c.Assert(err, check.Equals, nil)
for _, m := range ms {
c.Assert(m.Processing, check.Equals, true)
}
// Next, verify that campaigns scheduled in the future do not lock the
// maillogs so that they can be picked up by the background worker.
campaign = s.createCampaignDependencies(c)
campaign.Name = "New Campaign"
campaign.LaunchDate = time.Now().Add(1 * time.Hour)
c.Assert(PostCampaign(&campaign, campaign.UserId), check.Equals, nil)
ms, err = GetMailLogsByCampaign(campaign.Id)
c.Assert(err, check.Equals, nil)
for _, m := range ms {
c.Assert(m.Processing, check.Equals, false)
}
}

View File

@ -122,8 +122,8 @@ func (s *EmailRequest) Generate(msg *gomail.Message) error {
// Add the transparency headers // Add the transparency headers
msg.SetHeader("X-Mailer", config.ServerName) msg.SetHeader("X-Mailer", config.ServerName)
if config.Conf.ContactAddress != "" { if conf.ContactAddress != "" {
msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress) msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
} }
// Parse the customHeader templates // Parse the customHeader templates

View File

@ -26,7 +26,7 @@ func (s *ModelsSuite) TestEmailRequestBackoff(ch *check.C) {
} }
expected := errors.New("Temporary Error") expected := errors.New("Temporary Error")
go func() { go func() {
err = req.Backoff(expected) err := req.Backoff(expected)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
}() }()
ch.Assert(<-req.ErrorChan, check.Equals, expected) ch.Assert(<-req.ErrorChan, check.Equals, expected)
@ -38,7 +38,7 @@ func (s *ModelsSuite) TestEmailRequestError(ch *check.C) {
} }
expected := errors.New("Temporary Error") expected := errors.New("Temporary Error")
go func() { go func() {
err = req.Error(expected) err := req.Error(expected)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
}() }()
ch.Assert(<-req.ErrorChan, check.Equals, expected) ch.Assert(<-req.ErrorChan, check.Equals, expected)
@ -49,7 +49,7 @@ func (s *ModelsSuite) TestEmailRequestSuccess(ch *check.C) {
ErrorChan: make(chan error), ErrorChan: make(chan error),
} }
go func() { go func() {
err = req.Success() err := req.Success()
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
}() }()
ch.Assert(<-req.ErrorChan, check.Equals, nil) ch.Assert(<-req.ErrorChan, check.Equals, nil)
@ -76,14 +76,14 @@ func (s *ModelsSuite) TestEmailRequestGenerate(ch *check.C) {
FromAddress: smtp.FromAddress, FromAddress: smtp.FromAddress,
} }
config.Conf.ContactAddress = "test@test.com" s.config.ContactAddress = "test@test.com"
expectedHeaders := map[string]string{ expectedHeaders := map[string]string{
"X-Mailer": config.ServerName, "X-Mailer": config.ServerName,
"X-Gophish-Contact": config.Conf.ContactAddress, "X-Gophish-Contact": s.config.ContactAddress,
} }
msg := gomail.NewMessage() msg := gomail.NewMessage()
err = req.Generate(msg) err := req.Generate(msg)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
expected := &email.Email{ expected := &email.Email{
@ -130,7 +130,7 @@ func (s *ModelsSuite) TestEmailRequestURLTemplating(ch *check.C) {
} }
msg := gomail.NewMessage() msg := gomail.NewMessage()
err = req.Generate(msg) err := req.Generate(msg)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
expectedURL := fmt.Sprintf("http://127.0.0.1/%s?%s=%s", req.Email, RecipientParameter, req.RId) expectedURL := fmt.Sprintf("http://127.0.0.1/%s?%s=%s", req.Email, RecipientParameter, req.RId)
@ -167,7 +167,7 @@ func (s *ModelsSuite) TestEmailRequestGenerateEmptySubject(ch *check.C) {
} }
msg := gomail.NewMessage() msg := gomail.NewMessage()
err = req.Generate(msg) err := req.Generate(msg)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
expected := &email.Email{ expected := &email.Email{

View File

@ -196,7 +196,7 @@ func PostGroup(g *Group) error {
return err return err
} }
// Insert the group into the DB // Insert the group into the DB
err = db.Save(g).Error err := db.Save(g).Error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err
@ -214,7 +214,7 @@ func PutGroup(g *Group) error {
} }
// Fetch group's existing targets from database. // Fetch group's existing targets from database.
ts := []Target{} ts := []Target{}
ts, err = GetTargets(g.Id) ts, err := GetTargets(g.Id)
if err != nil { if err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"group_id": g.Id, "group_id": g.Id,
@ -234,7 +234,7 @@ func PutGroup(g *Group) error {
} }
// If the target does not exist in the group any longer, we delete it // If the target does not exist in the group any longer, we delete it
if !tExists { if !tExists {
err = db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error err := db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error
if err != nil { if err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"email": t.Email, "email": t.Email,
@ -286,14 +286,14 @@ func DeleteGroup(g *Group) error {
} }
func insertTargetIntoGroup(t Target, gid int64) error { func insertTargetIntoGroup(t Target, gid int64) error {
if _, err = mail.ParseAddress(t.Email); err != nil { if _, err := mail.ParseAddress(t.Email); err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"email": t.Email, "email": t.Email,
}).Error("Invalid email") }).Error("Invalid email")
return err return err
} }
trans := db.Begin() trans := db.Begin()
err = trans.Where(t).FirstOrCreate(&t).Error err := trans.Where(t).FirstOrCreate(&t).Error
if err != nil { if err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"email": t.Email, "email": t.Email,

View File

@ -45,8 +45,7 @@ func GenerateMailLog(c *Campaign, r *Result, sendDate time.Time) error {
RId: r.RId, RId: r.RId,
SendDate: sendDate, SendDate: sendDate,
} }
err = db.Save(m).Error return db.Save(m).Error
return err
} }
// Backoff sets the MailLog SendDate to be the next entry in an exponential // Backoff sets the MailLog SendDate to be the next entry in an exponential
@ -160,8 +159,8 @@ func (m *MailLog) Generate(msg *gomail.Message) error {
// Add the transparency headers // Add the transparency headers
msg.SetHeader("X-Mailer", config.ServerName) msg.SetHeader("X-Mailer", config.ServerName)
if config.Conf.ContactAddress != "" { if conf.ContactAddress != "" {
msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress) msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
} }
// Parse the customHeader templates // Parse the customHeader templates
for _, header := range c.SMTP.Headers { for _, header := range c.SMTP.Headers {
@ -260,6 +259,5 @@ func LockMailLogs(ms []*MailLog, lock bool) error {
// in the database. This is intended to be called when Gophish is started // in the database. This is intended to be called when Gophish is started
// so that any previously locked maillogs can resume processing. // so that any previously locked maillogs can resume processing.
func UnlockAllMailLogs() error { func UnlockAllMailLogs() error {
err = db.Model(&MailLog{}).Update("processing", false).Error return db.Model(&MailLog{}).Update("processing", false).Error
return err
} }

View File

@ -37,7 +37,13 @@ func (s *ModelsSuite) emailFromFirstMailLog(campaign Campaign, ch *check.C) *ema
func (s *ModelsSuite) TestGetQueuedMailLogs(ch *check.C) { func (s *ModelsSuite) TestGetQueuedMailLogs(ch *check.C) {
campaign := s.createCampaign(ch) campaign := s.createCampaign(ch)
ms, err := GetQueuedMailLogs(campaign.LaunchDate) // By default, for campaigns with no launch date, the maillogs are set as
// being processed. We need to unlock them first.
ms, err := GetMailLogsByCampaign(campaign.Id)
ch.Assert(err, check.Equals, nil)
err = LockMailLogs(ms, false)
ch.Assert(err, check.Equals, nil)
ms, err = GetQueuedMailLogs(campaign.LaunchDate)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
got := make(map[string]*MailLog) got := make(map[string]*MailLog)
for _, m := range ms { for _, m := range ms {
@ -222,10 +228,10 @@ func (s *ModelsSuite) TestMailLogGenerate(ch *check.C) {
} }
func (s *ModelsSuite) TestMailLogGenerateTransparencyHeaders(ch *check.C) { func (s *ModelsSuite) TestMailLogGenerateTransparencyHeaders(ch *check.C) {
config.Conf.ContactAddress = "test@test.com" s.config.ContactAddress = "test@test.com"
expectedHeaders := map[string]string{ expectedHeaders := map[string]string{
"X-Mailer": config.ServerName, "X-Mailer": config.ServerName,
"X-Gophish-Contact": config.Conf.ContactAddress, "X-Gophish-Contact": s.config.ContactAddress,
} }
campaign := s.createCampaign(ch) campaign := s.createCampaign(ch)
got := s.emailFromFirstMailLog(campaign, ch) got := s.emailFromFirstMailLog(campaign, ch)
@ -264,12 +270,6 @@ func (s *ModelsSuite) TestUnlockAllMailLogs(ch *check.C) {
campaign := s.createCampaign(ch) campaign := s.createCampaign(ch)
ms, err := GetMailLogsByCampaign(campaign.Id) ms, err := GetMailLogsByCampaign(campaign.Id)
ch.Assert(err, check.Equals, nil) ch.Assert(err, check.Equals, nil)
for _, m := range ms {
ch.Assert(m.Processing, check.Equals, false)
}
err = LockMailLogs(ms, true)
ms, err = GetMailLogsByCampaign(campaign.Id)
ch.Assert(err, check.Equals, nil)
for _, m := range ms { for _, m := range ms {
ch.Assert(m.Processing, check.Equals, true) ch.Assert(m.Processing, check.Equals, true)
} }

View File

@ -15,7 +15,7 @@ import (
) )
var db *gorm.DB var db *gorm.DB
var err error var conf *config.Config
const ( const (
CAMPAIGN_IN_PROGRESS string = "In progress" CAMPAIGN_IN_PROGRESS string = "In progress"
@ -78,12 +78,14 @@ func chooseDBDriver(name, openStr string) goose.DBDriver {
// Setup initializes the Conn object // Setup initializes the Conn object
// It also populates the Gophish Config object // It also populates the Gophish Config object
func Setup() error { func Setup(c *config.Config) error {
// Setup the package-scoped config
conf = c
// Setup the goose configuration // Setup the goose configuration
migrateConf := &goose.DBConf{ migrateConf := &goose.DBConf{
MigrationsDir: config.Conf.MigrationsPath, MigrationsDir: conf.MigrationsPath,
Env: "production", Env: "production",
Driver: chooseDBDriver(config.Conf.DBName, config.Conf.DBPath), Driver: chooseDBDriver(conf.DBName, conf.DBPath),
} }
// Get the latest possible migration // Get the latest possible migration
latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir) latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir)
@ -92,7 +94,7 @@ func Setup() error {
return err return err
} }
// Open our database connection // Open our database connection
db, err = gorm.Open(config.Conf.DBName, config.Conf.DBPath) db, err = gorm.Open(conf.DBName, conf.DBPath)
db.LogMode(false) db.LogMode(false)
db.SetLogger(log.Logger) db.SetLogger(log.Logger)
db.DB().SetMaxOpenConns(1) db.DB().SetMaxOpenConns(1)

View File

@ -10,15 +10,20 @@ import (
// Hook up gocheck into the "go test" runner. // Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) } func Test(t *testing.T) { check.TestingT(t) }
type ModelsSuite struct{} type ModelsSuite struct {
config *config.Config
}
var _ = check.Suite(&ModelsSuite{}) var _ = check.Suite(&ModelsSuite{})
func (s *ModelsSuite) SetUpSuite(c *check.C) { func (s *ModelsSuite) SetUpSuite(c *check.C) {
config.Conf.DBName = "sqlite3" conf := &config.Config{
config.Conf.DBPath = ":memory:" DBName: "sqlite3",
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" DBPath: ":memory:",
err := Setup() MigrationsPath: "../db/db_sqlite3/migrations/",
}
s.config = conf
err := Setup(conf)
if err != nil { if err != nil {
c.Fatalf("Failed creating database: %v", err) c.Fatalf("Failed creating database: %v", err)
} }

View File

@ -148,7 +148,7 @@ func PutPage(p *Page) error {
// DeletePage deletes an existing page in the database. // DeletePage deletes an existing page in the database.
// An error is returned if a page with the given user id and page id is not found. // An error is returned if a page with the given user id and page id is not found.
func DeletePage(id int64, uid int64) error { func DeletePage(id int64, uid int64) error {
err = db.Where("user_id=?", uid).Delete(Page{Id: id}).Error err := db.Where("user_id=?", uid).Delete(Page{Id: id}).Error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }

View File

@ -228,7 +228,7 @@ func PutSMTP(s *SMTP) error {
// An error is returned if a SMTP with the given user id and SMTP id is not found. // An error is returned if a SMTP with the given user id and SMTP id is not found.
func DeleteSMTP(id int64, uid int64) error { func DeleteSMTP(id int64, uid int64) error {
// Delete all custom headers // Delete all custom headers
err = db.Where("smtp_id=?", id).Delete(&Header{}).Error err := db.Where("smtp_id=?", id).Delete(&Header{}).Error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err

View File

@ -15,7 +15,7 @@ func (s *ModelsSuite) TestPostSMTP(c *check.C) {
FromAddress: "Foo Bar <foo@example.com>", FromAddress: "Foo Bar <foo@example.com>",
UserId: 1, UserId: 1,
} }
err = PostSMTP(&smtp) err := PostSMTP(&smtp)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
ss, err := GetSMTPs(1) ss, err := GetSMTPs(1)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
@ -28,7 +28,7 @@ func (s *ModelsSuite) TestPostSMTPNoHost(c *check.C) {
FromAddress: "Foo Bar <foo@example.com>", FromAddress: "Foo Bar <foo@example.com>",
UserId: 1, UserId: 1,
} }
err = PostSMTP(&smtp) err := PostSMTP(&smtp)
c.Assert(err, check.Equals, ErrHostNotSpecified) c.Assert(err, check.Equals, ErrHostNotSpecified)
} }
@ -38,7 +38,7 @@ func (s *ModelsSuite) TestPostSMTPNoFrom(c *check.C) {
UserId: 1, UserId: 1,
Host: "1.1.1.1:25", Host: "1.1.1.1:25",
} }
err = PostSMTP(&smtp) err := PostSMTP(&smtp)
c.Assert(err, check.Equals, ErrFromAddressNotSpecified) c.Assert(err, check.Equals, ErrFromAddressNotSpecified)
} }
@ -53,7 +53,7 @@ func (s *ModelsSuite) TestPostSMTPValidHeader(c *check.C) {
Header{Key: "X-Mailer", Value: "gophish"}, Header{Key: "X-Mailer", Value: "gophish"},
}, },
} }
err = PostSMTP(&smtp) err := PostSMTP(&smtp)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)
ss, err := GetSMTPs(1) ss, err := GetSMTPs(1)
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, nil)

View File

@ -34,10 +34,10 @@ func (t *Template) Validate() error {
case t.Text == "" && t.HTML == "": case t.Text == "" && t.HTML == "":
return ErrTemplateMissingParameter return ErrTemplateMissingParameter
} }
if err = ValidateTemplate(t.HTML); err != nil { if err := ValidateTemplate(t.HTML); err != nil {
return err return err
} }
if err = ValidateTemplate(t.Text); err != nil { if err := ValidateTemplate(t.Text); err != nil {
return err return err
} }
return nil return nil
@ -113,7 +113,7 @@ func PostTemplate(t *Template) error {
if err := t.Validate(); err != nil { if err := t.Validate(); err != nil {
return err return err
} }
err = db.Save(t).Error err := db.Save(t).Error
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return err return err
@ -138,7 +138,7 @@ func PutTemplate(t *Template) error {
return err return err
} }
// Delete all attachments, and replace with new ones // Delete all attachments, and replace with new ones
err = db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error err := db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error
if err != nil && err != gorm.ErrRecordNotFound { if err != nil && err != gorm.ErrRecordNotFound {
log.Error(err) log.Error(err)
return err return err
@ -146,7 +146,7 @@ func PutTemplate(t *Template) error {
if err == gorm.ErrRecordNotFound { if err == gorm.ErrRecordNotFound {
err = nil err = nil
} }
for i, _ := range t.Attachments { for i := range t.Attachments {
t.Attachments[i].TemplateId = t.Id t.Attachments[i].TemplateId = t.Id
err := db.Save(&t.Attachments[i]).Error err := db.Save(&t.Attachments[i]).Error
if err != nil { if err != nil {

View File

@ -18,10 +18,12 @@ type UtilSuite struct {
} }
func (s *UtilSuite) SetupSuite() { func (s *UtilSuite) SetupSuite() {
config.Conf.DBName = "sqlite3" conf := &config.Config{
config.Conf.DBPath = ":memory:" DBName: "sqlite3",
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" DBPath: ":memory:",
err := models.Setup() MigrationsPath: "../db/db_sqlite3/migrations/",
}
err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) s.T().Fatalf("Failed creating database: %v", err)
} }

View File

@ -1,6 +1,7 @@
package worker package worker
import ( import (
"context"
"time" "time"
log "github.com/gophish/gophish/logger" log "github.com/gophish/gophish/logger"
@ -9,18 +10,46 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Worker is the background worker that handles watching for new campaigns and sending emails appropriately. // Worker is an interface that defines the operations needed for a background worker
type Worker struct{} type Worker interface {
Start()
LaunchCampaign(c models.Campaign)
SendTestEmail(s *models.EmailRequest) error
}
// DefaultWorker is the background worker that handles watching for new campaigns and sending emails appropriately.
type DefaultWorker struct {
mailer mailer.Mailer
}
// New creates a new worker object to handle the creation of campaigns // New creates a new worker object to handle the creation of campaigns
func New() *Worker { func New(options ...func(Worker) error) (Worker, error) {
return &Worker{} defaultMailer := mailer.NewMailWorker()
w := &DefaultWorker{
mailer: defaultMailer,
}
for _, opt := range options {
if err := opt(w); err != nil {
return nil, err
}
}
return w, nil
}
// WithMailer sets the mailer for a given worker.
// By default, workers use a standard, default mailworker.
func WithMailer(m mailer.Mailer) func(*DefaultWorker) error {
return func(w *DefaultWorker) error {
w.mailer = m
return nil
}
} }
// Start launches the worker to poll the database every minute for any pending maillogs // Start launches the worker to poll the database every minute for any pending maillogs
// that need to be processed. // that need to be processed.
func (w *Worker) Start() { func (w *DefaultWorker) Start() {
log.Info("Background Worker Started Successfully - Waiting for Campaigns") log.Info("Background Worker Started Successfully - Waiting for Campaigns")
go w.mailer.Start(context.Background())
for t := range time.Tick(1 * time.Minute) { for t := range time.Tick(1 * time.Minute) {
ms, err := models.GetQueuedMailLogs(t.UTC()) ms, err := models.GetQueuedMailLogs(t.UTC())
if err != nil { if err != nil {
@ -62,14 +91,14 @@ func (w *Worker) Start() {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"num_emails": len(msc), "num_emails": len(msc),
}).Info("Sending emails to mailer for processing") }).Info("Sending emails to mailer for processing")
mailer.Mailer.Queue <- msc w.mailer.Queue(msc)
}(cid, msc) }(cid, msc)
} }
} }
} }
// LaunchCampaign starts a campaign // LaunchCampaign starts a campaign
func (w *Worker) LaunchCampaign(c models.Campaign) { func (w *DefaultWorker) LaunchCampaign(c models.Campaign) {
ms, err := models.GetMailLogsByCampaign(c.Id) ms, err := models.GetMailLogsByCampaign(c.Id)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
@ -89,14 +118,14 @@ func (w *Worker) LaunchCampaign(c models.Campaign) {
} }
mailEntries = append(mailEntries, m) mailEntries = append(mailEntries, m)
} }
mailer.Mailer.Queue <- mailEntries w.mailer.Queue(mailEntries)
} }
// SendTestEmail sends a test email // SendTestEmail sends a test email
func (w *Worker) SendTestEmail(s *models.EmailRequest) error { func (w *DefaultWorker) SendTestEmail(s *models.EmailRequest) error {
go func() { go func() {
ms := []mailer.Mail{s} ms := []mailer.Mail{s}
mailer.Mailer.Queue <- ms w.mailer.Queue(ms)
}() }()
return <-s.ErrorChan return <-s.ErrorChan
} }

View File

@ -9,17 +9,20 @@ import (
// WorkerSuite is a suite of tests to cover API related functions // WorkerSuite is a suite of tests to cover API related functions
type WorkerSuite struct { type WorkerSuite struct {
suite.Suite suite.Suite
ApiKey string config *config.Config
} }
func (s *WorkerSuite) SetupSuite() { func (s *WorkerSuite) SetupSuite() {
config.Conf.DBName = "sqlite3" conf := &config.Config{
config.Conf.DBPath = ":memory:" DBName: "sqlite3",
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/" DBPath: ":memory:",
err := models.Setup() MigrationsPath: "../db/db_sqlite3/migrations/",
}
err := models.Setup(conf)
if err != nil { if err != nil {
s.T().Fatalf("Failed creating database: %v", err) s.T().Fatalf("Failed creating database: %v", err)
} }
s.config = conf
s.Nil(err) s.Nil(err)
} }
@ -31,7 +34,7 @@ func (s *WorkerSuite) TearDownTest() {
} }
func (s *WorkerSuite) SetupTest() { func (s *WorkerSuite) SetupTest() {
config.Conf.TestFlag = true s.config.TestFlag = true
// Add a group // Add a group
group := models.Group{Name: "Test Group"} group := models.Group{Name: "Test Group"}
group.Targets = []models.Target{ group.Targets = []models.Target{
@ -73,7 +76,3 @@ func (s *WorkerSuite) SetupTest() {
models.PostCampaign(&c, c.UserId) models.PostCampaign(&c, c.UserId)
c.UpdateStatus(models.CAMPAIGN_EMAILS_SENT) c.UpdateStatus(models.CAMPAIGN_EMAILS_SENT)
} }
func (s *WorkerSuite) TestMailSendSuccess() {
// TODO
}