mirror of https://github.com/gophish/gophish
Refactor servers (#1321)
* Refactoring servers to support custom workers and graceful shutdown. * Refactoring workers to support custom mailers. * Refactoring mailer to be an interface, with proper instances instead of a single global instance * Cleaning up a few things. Locking maillogs for campaigns set to launch immediately to prevent a race condition. * Cleaning up API middleware to be simpler * Moving template parameters to separate struct * Changed LoadConfig to return config object * Cleaned up some error handling, removing uninitialized global error in models package * Changed static file serving to use the unindexed packagepull/1323/head
parent
3b248d25c7
commit
47f0049c30
|
@ -38,9 +38,6 @@ type Config struct {
|
|||
Logging LoggingConfig `json:"logging"`
|
||||
}
|
||||
|
||||
// Conf contains the initialized configuration struct
|
||||
var Conf Config
|
||||
|
||||
// Version contains the current gophish version
|
||||
var Version = ""
|
||||
|
||||
|
@ -48,19 +45,20 @@ var Version = ""
|
|||
const ServerName = "gophish"
|
||||
|
||||
// LoadConfig loads the configuration from the specified filepath
|
||||
func LoadConfig(filepath string) error {
|
||||
func LoadConfig(filepath string) (*Config, error) {
|
||||
// Get the config file
|
||||
configFile, err := ioutil.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(configFile, &Conf)
|
||||
config := &Config{}
|
||||
err = json.Unmarshal(configFile, config)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
// Choosing the migrations directory based on the database used.
|
||||
Conf.MigrationsPath = Conf.MigrationsPath + Conf.DBName
|
||||
config.MigrationsPath = config.MigrationsPath + config.DBName
|
||||
// Explicitly set the TestFlag to false to prevent config.json overrides
|
||||
Conf.TestFlag = false
|
||||
return nil
|
||||
config.TestFlag = false
|
||||
return config, nil
|
||||
}
|
||||
|
|
|
@ -48,18 +48,18 @@ func (s *ConfigSuite) TestLoadConfig() {
|
|||
_, err := s.ConfigFile.Write(validConfig)
|
||||
s.Nil(err)
|
||||
// Load the valid config
|
||||
err = LoadConfig(s.ConfigFile.Name())
|
||||
conf, err := LoadConfig(s.ConfigFile.Name())
|
||||
s.Nil(err)
|
||||
|
||||
expectedConfig := Config{}
|
||||
expectedConfig := &Config{}
|
||||
err = json.Unmarshal(validConfig, &expectedConfig)
|
||||
s.Nil(err)
|
||||
expectedConfig.MigrationsPath = expectedConfig.MigrationsPath + expectedConfig.DBName
|
||||
expectedConfig.TestFlag = false
|
||||
s.Equal(expectedConfig, Conf)
|
||||
s.Equal(expectedConfig, conf)
|
||||
|
||||
// Load an invalid config
|
||||
err = LoadConfig("bogusfile")
|
||||
conf, err = LoadConfig("bogusfile")
|
||||
s.NotNil(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,23 +17,14 @@ import (
|
|||
log "github.com/gophish/gophish/logger"
|
||||
"github.com/gophish/gophish/models"
|
||||
"github.com/gophish/gophish/util"
|
||||
"github.com/gophish/gophish/worker"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jordan-wright/email"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Worker is the worker that processes phishing events and updates campaigns.
|
||||
var Worker *worker.Worker
|
||||
|
||||
func init() {
|
||||
Worker = worker.New()
|
||||
go Worker.Start()
|
||||
}
|
||||
|
||||
// API (/api/reset) resets a user's API key
|
||||
func API_Reset(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Reset(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "POST":
|
||||
u := ctx.Get(r, "user").(models.User)
|
||||
|
@ -49,7 +40,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Campaigns returns a list of campaigns if requested via GET.
|
||||
// If requested via POST, API_Campaigns creates a new campaign and returns a reference to it.
|
||||
func API_Campaigns(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaigns(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
cs, err := models.GetCampaigns(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -74,14 +65,14 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) {
|
|||
// If the campaign is scheduled to launch immediately, send it to the worker.
|
||||
// Otherwise, the worker will pick it up at the scheduled time
|
||||
if c.Status == models.CAMPAIGN_IN_PROGRESS {
|
||||
go Worker.LaunchCampaign(c)
|
||||
go as.worker.LaunchCampaign(c)
|
||||
}
|
||||
JSONResponse(w, c, http.StatusCreated)
|
||||
}
|
||||
}
|
||||
|
||||
// API_Campaigns_Summary returns the summary for the current user's campaigns
|
||||
func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
cs, err := models.GetCampaignSummaries(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -96,7 +87,7 @@ func API_Campaigns_Summary(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Campaigns_Id returns details about the requested campaign. If the campaign is not
|
||||
// valid, API_Campaigns_Id returns null.
|
||||
func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
c, err := models.GetCampaign(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -120,7 +111,7 @@ func API_Campaigns_Id(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Campaigns_Id_Results returns just the results for a given campaign to
|
||||
// significantly reduce the information returned.
|
||||
func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
cr, err := models.GetCampaignResults(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -136,7 +127,7 @@ func API_Campaigns_Id_Results(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Campaigns_Id_Summary returns just the summary for a given campaign.
|
||||
func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
switch {
|
||||
|
@ -157,7 +148,7 @@ func API_Campaign_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Campaigns_Id_Complete effectively "ends" a campaign.
|
||||
// Future phishing emails clicked will return a simple "404" page.
|
||||
func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
switch {
|
||||
|
@ -173,7 +164,7 @@ func API_Campaigns_Id_Complete(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Groups returns a list of groups if requested via GET.
|
||||
// If requested via POST, API_Groups creates a new group and returns a reference to it.
|
||||
func API_Groups(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Groups(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
gs, err := models.GetGroups(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -208,7 +199,7 @@ func API_Groups(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Groups_Summary returns a summary of the groups owned by the current user.
|
||||
func API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
gs, err := models.GetGroupSummaries(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -223,7 +214,7 @@ func API_Groups_Summary(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Groups_Id returns details about the requested group.
|
||||
// If the group is not valid, API_Groups_Id returns null.
|
||||
func API_Groups_Id(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Groups_Id(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -261,7 +252,7 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Groups_Id_Summary returns a summary of the groups owned by the current user.
|
||||
func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
vars := mux.Vars(r)
|
||||
|
@ -276,7 +267,7 @@ func API_Groups_Id_Summary(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Templates handles the functionality for the /api/templates endpoint
|
||||
func API_Templates(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Templates(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
ts, err := models.GetTemplates(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -319,7 +310,7 @@ func API_Templates(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Templates_Id handles the functions for the /api/templates/:id endpoint
|
||||
func API_Templates_Id(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Templates_Id(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
t, err := models.GetTemplate(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -359,7 +350,7 @@ func API_Templates_Id(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Pages handles requests for the /api/pages/ endpoint
|
||||
func API_Pages(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Pages(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
ps, err := models.GetPages(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -396,7 +387,7 @@ func API_Pages(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Pages_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
|
||||
// of a Page object
|
||||
func API_Pages_Id(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Pages_Id(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
p, err := models.GetPage(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -436,7 +427,7 @@ func API_Pages_Id(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_SMTP handles requests for the /api/smtp/ endpoint
|
||||
func API_SMTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_SMTP(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
ss, err := models.GetSMTPs(ctx.Get(r, "user_id").(int64))
|
||||
|
@ -473,7 +464,7 @@ func API_SMTP(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_SMTP_Id contains functions to handle the GET'ing, DELETE'ing, and PUT'ing
|
||||
// of a SMTP object
|
||||
func API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64)
|
||||
s, err := models.GetSMTP(id, ctx.Get(r, "user_id").(int64))
|
||||
|
@ -518,7 +509,7 @@ func API_SMTP_Id(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// API_Import_Group imports a CSV of group members
|
||||
func API_Import_Group(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Import_Group(w http.ResponseWriter, r *http.Request) {
|
||||
ts, err := util.ParseCSV(r)
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Error parsing CSV"}, http.StatusInternalServerError)
|
||||
|
@ -530,7 +521,7 @@ func API_Import_Group(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Import_Email allows for the importing of email.
|
||||
// Returns a Message object
|
||||
func API_Import_Email(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Import_Email(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
|
||||
return
|
||||
|
@ -579,7 +570,7 @@ func API_Import_Email(w http.ResponseWriter, r *http.Request) {
|
|||
// API_Import_Site allows for the importing of HTML from a website
|
||||
// Without "include_resources" set, it will merely place a "base" tag
|
||||
// so that all resources can be loaded relative to the given URL.
|
||||
func API_Import_Site(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Import_Site(w http.ResponseWriter, r *http.Request) {
|
||||
cr := cloneRequest{}
|
||||
if r.Method != "POST" {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Method not allowed"}, http.StatusBadRequest)
|
||||
|
@ -637,7 +628,7 @@ func API_Import_Site(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// API_Send_Test_Email sends a test email using the template name
|
||||
// and Target given.
|
||||
func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
|
||||
s := &models.EmailRequest{
|
||||
ErrorChan: make(chan error),
|
||||
UserId: ctx.Get(r, "user_id").(int64),
|
||||
|
@ -735,7 +726,7 @@ func API_Send_Test_Email(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
// Send the test email
|
||||
err = Worker.SendTestEmail(s)
|
||||
err = as.worker.SendTestEmail(s)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
|
||||
|
|
|
@ -11,41 +11,42 @@ import (
|
|||
|
||||
"github.com/gophish/gophish/config"
|
||||
"github.com/gophish/gophish/models"
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ControllersSuite is a suite of tests to cover API related functions
|
||||
type ControllersSuite struct {
|
||||
suite.Suite
|
||||
ApiKey string
|
||||
ApiKey string
|
||||
config *config.Config
|
||||
adminServer *httptest.Server
|
||||
phishServer *httptest.Server
|
||||
}
|
||||
|
||||
// as is the Admin Server for our API calls
|
||||
var as *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreateAdminRouter()))
|
||||
|
||||
// ps is the Phishing Server
|
||||
var ps *httptest.Server = httptest.NewUnstartedServer(handlers.CombinedLoggingHandler(os.Stdout, CreatePhishingRouter()))
|
||||
|
||||
func (s *ControllersSuite) SetupSuite() {
|
||||
config.Conf.DBName = "sqlite3"
|
||||
config.Conf.DBPath = ":memory:"
|
||||
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
|
||||
err := models.Setup()
|
||||
conf := &config.Config{
|
||||
DBName: "sqlite3",
|
||||
DBPath: ":memory:",
|
||||
MigrationsPath: "../db/db_sqlite3/migrations/",
|
||||
}
|
||||
err := models.Setup(conf)
|
||||
if err != nil {
|
||||
s.T().Fatalf("Failed creating database: %v", err)
|
||||
}
|
||||
s.config = conf
|
||||
s.Nil(err)
|
||||
// Setup the admin server for use in testing
|
||||
as.Config.Addr = config.Conf.AdminConf.ListenURL
|
||||
as.Start()
|
||||
s.adminServer = httptest.NewUnstartedServer(NewAdminServer(s.config.AdminConf).server.Handler)
|
||||
s.adminServer.Config.Addr = s.config.AdminConf.ListenURL
|
||||
s.adminServer.Start()
|
||||
// Get the API key to use for these tests
|
||||
u, err := models.GetUser(1)
|
||||
s.Nil(err)
|
||||
s.ApiKey = u.ApiKey
|
||||
// Start the phishing server
|
||||
ps.Config.Addr = config.Conf.PhishConf.ListenURL
|
||||
ps.Start()
|
||||
s.phishServer = httptest.NewUnstartedServer(NewPhishingServer(s.config.PhishConf).server.Handler)
|
||||
s.phishServer.Config.Addr = s.config.PhishConf.ListenURL
|
||||
s.phishServer.Start()
|
||||
// Move our cwd up to the project root for help with resolving
|
||||
// static assets
|
||||
err = os.Chdir("../")
|
||||
|
@ -103,21 +104,21 @@ func (s *ControllersSuite) SetupTest() {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) TestRequireAPIKey() {
|
||||
resp, err := http.Post(fmt.Sprintf("%s/api/import/site", as.URL), "application/json", nil)
|
||||
resp, err := http.Post(fmt.Sprintf("%s/api/import/site", s.adminServer.URL), "application/json", nil)
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Equal(resp.StatusCode, http.StatusBadRequest)
|
||||
s.Equal(resp.StatusCode, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) TestInvalidAPIKey() {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", as.URL, "bogus-api-key"))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/api/groups/?api_key=%s", s.adminServer.URL, "bogus-api-key"))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Equal(resp.StatusCode, http.StatusBadRequest)
|
||||
s.Equal(resp.StatusCode, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) TestBearerToken() {
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", as.URL), nil)
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/groups/", s.adminServer.URL), nil)
|
||||
s.Nil(err)
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", s.ApiKey))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
|
@ -133,7 +134,7 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
|
|||
}))
|
||||
hr := fmt.Sprintf("<html><head><base href=\"%s\"/></head><body><img src=\"/test.png\"/>\n</body></html>", ts.URL)
|
||||
defer ts.Close()
|
||||
resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", as.URL, s.ApiKey), "application/json",
|
||||
resp, err := http.Post(fmt.Sprintf("%s/api/import/site?api_key=%s", s.adminServer.URL, s.ApiKey), "application/json",
|
||||
bytes.NewBuffer([]byte(fmt.Sprintf(`
|
||||
{
|
||||
"url" : "%s",
|
||||
|
@ -150,8 +151,8 @@ func (s *ControllersSuite) TestSiteImportBaseHref() {
|
|||
|
||||
func (s *ControllersSuite) TearDownSuite() {
|
||||
// Tear down the admin and phishing servers
|
||||
as.Close()
|
||||
ps.Close()
|
||||
s.adminServer.Close()
|
||||
s.phishServer.Close()
|
||||
}
|
||||
|
||||
func TestControllerSuite(t *testing.T) {
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -8,11 +10,15 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gophish/gophish/config"
|
||||
ctx "github.com/gophish/gophish/context"
|
||||
log "github.com/gophish/gophish/logger"
|
||||
"github.com/gophish/gophish/models"
|
||||
"github.com/gophish/gophish/util"
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/jordan-wright/unindexed"
|
||||
)
|
||||
|
||||
// ErrInvalidRequest is thrown when a request with an invalid structure is
|
||||
|
@ -35,22 +41,91 @@ type TransparencyResponse struct {
|
|||
// to return a transparency response.
|
||||
const TransparencySuffix = "+"
|
||||
|
||||
// CreatePhishingRouter creates the router that handles phishing connections.
|
||||
func CreatePhishingRouter() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
fileServer := http.FileServer(UnindexedFileSystem{http.Dir("./static/endpoint/")})
|
||||
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
|
||||
router.HandleFunc("/track", PhishTracker)
|
||||
router.HandleFunc("/robots.txt", RobotsHandler)
|
||||
router.HandleFunc("/{path:.*}/track", PhishTracker)
|
||||
router.HandleFunc("/{path:.*}/report", PhishReporter)
|
||||
router.HandleFunc("/report", PhishReporter)
|
||||
router.HandleFunc("/{path:.*}", PhishHandler)
|
||||
return router
|
||||
// PhishingServerOption is a functional option that is used to configure the
|
||||
// the phishing server
|
||||
type PhishingServerOption func(*PhishingServer)
|
||||
|
||||
// PhishingServer is an HTTP server that implements the campaign event
|
||||
// handlers, such as email open tracking, click tracking, and more.
|
||||
type PhishingServer struct {
|
||||
server *http.Server
|
||||
config config.PhishServer
|
||||
contactAddress string
|
||||
}
|
||||
|
||||
// PhishTracker tracks emails as they are opened, updating the status for the given Result
|
||||
func PhishTracker(w http.ResponseWriter, r *http.Request) {
|
||||
// NewPhishingServer returns a new instance of the phishing server with
|
||||
// provided options applied.
|
||||
func NewPhishingServer(config config.PhishServer, options ...PhishingServerOption) *PhishingServer {
|
||||
defaultServer := &http.Server{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
Addr: config.ListenURL,
|
||||
}
|
||||
ps := &PhishingServer{
|
||||
server: defaultServer,
|
||||
config: config,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(ps)
|
||||
}
|
||||
ps.registerRoutes()
|
||||
return ps
|
||||
}
|
||||
|
||||
// WithContactAddress sets the contact address used by the transparency
|
||||
// handlers
|
||||
func WithContactAddress(addr string) PhishingServerOption {
|
||||
return func(ps *PhishingServer) {
|
||||
ps.contactAddress = addr
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the phishing server, listening on the configured address.
|
||||
func (ps *PhishingServer) Start() error {
|
||||
if ps.config.UseTLS {
|
||||
err := util.CheckAndCreateSSL(ps.config.CertPath, ps.config.KeyPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
log.Infof("Starting phishing server at https://%s", ps.config.ListenURL)
|
||||
return ps.server.ListenAndServeTLS(ps.config.CertPath, ps.config.KeyPath)
|
||||
}
|
||||
// If TLS isn't configured, just listen on HTTP
|
||||
log.Infof("Starting phishing server at http://%s", ps.config.ListenURL)
|
||||
return ps.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown attempts to gracefully shutdown the server.
|
||||
func (ps *PhishingServer) Shutdown() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
return ps.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// CreatePhishingRouter creates the router that handles phishing connections.
|
||||
func (ps *PhishingServer) registerRoutes() {
|
||||
router := mux.NewRouter()
|
||||
fileServer := http.FileServer(unindexed.Dir("./static/endpoint/"))
|
||||
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer))
|
||||
router.HandleFunc("/track", ps.TrackHandler)
|
||||
router.HandleFunc("/robots.txt", ps.RobotsHandler)
|
||||
router.HandleFunc("/{path:.*}/track", ps.TrackHandler)
|
||||
router.HandleFunc("/{path:.*}/report", ps.ReportHandler)
|
||||
router.HandleFunc("/report", ps.ReportHandler)
|
||||
router.HandleFunc("/{path:.*}", ps.PhishHandler)
|
||||
|
||||
// Setup GZIP compression
|
||||
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
|
||||
phishHandler := gzipWrapper(router)
|
||||
|
||||
// Setup logging
|
||||
phishHandler = handlers.CombinedLoggingHandler(log.Writer(), phishHandler)
|
||||
ps.server.Handler = router
|
||||
}
|
||||
|
||||
// TrackHandler tracks emails as they are opened, updating the status for the given Result
|
||||
func (ps *PhishingServer) TrackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err, r := setupContext(r)
|
||||
if err != nil {
|
||||
// Log the error if it wasn't something we can safely ignore
|
||||
|
@ -71,7 +146,7 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Check for a transparency request
|
||||
if strings.HasSuffix(rid, TransparencySuffix) {
|
||||
TransparencyHandler(w, r)
|
||||
ps.TransparencyHandler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -82,8 +157,8 @@ func PhishTracker(w http.ResponseWriter, r *http.Request) {
|
|||
http.ServeFile(w, r, "static/images/pixel.png")
|
||||
}
|
||||
|
||||
// PhishReporter tracks emails as they are reported, updating the status for the given Result
|
||||
func PhishReporter(w http.ResponseWriter, r *http.Request) {
|
||||
// ReportHandler tracks emails as they are reported, updating the status for the given Result
|
||||
func (ps *PhishingServer) ReportHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err, r := setupContext(r)
|
||||
if err != nil {
|
||||
// Log the error if it wasn't something we can safely ignore
|
||||
|
@ -104,7 +179,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Check for a transparency request
|
||||
if strings.HasSuffix(rid, TransparencySuffix) {
|
||||
TransparencyHandler(w, r)
|
||||
ps.TransparencyHandler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -117,7 +192,7 @@ func PhishReporter(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// PhishHandler handles incoming client connections and registers the associated actions performed
|
||||
// (such as clicked link, etc.)
|
||||
func PhishHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (ps *PhishingServer) PhishHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err, r := setupContext(r)
|
||||
if err != nil {
|
||||
// Log the error if it wasn't something we can safely ignore
|
||||
|
@ -152,7 +227,7 @@ func PhishHandler(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Check for a transparency request
|
||||
if strings.HasSuffix(rid, TransparencySuffix) {
|
||||
TransparencyHandler(w, r)
|
||||
ps.TransparencyHandler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -211,23 +286,24 @@ func renderPhishResponse(w http.ResponseWriter, r *http.Request, ptx models.Phis
|
|||
}
|
||||
|
||||
// RobotsHandler prevents search engines, etc. from indexing phishing materials
|
||||
func RobotsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (ps *PhishingServer) RobotsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "User-agent: *\nDisallow: /")
|
||||
}
|
||||
|
||||
// TransparencyHandler returns a TransparencyResponse for the provided result
|
||||
// and campaign.
|
||||
func TransparencyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (ps *PhishingServer) TransparencyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
rs := ctx.Get(r, "result").(models.Result)
|
||||
tr := &TransparencyResponse{
|
||||
Server: config.ServerName,
|
||||
SendDate: rs.SendDate,
|
||||
ContactAddress: config.Conf.ContactAddress,
|
||||
ContactAddress: ps.contactAddress,
|
||||
}
|
||||
JSONResponse(w, tr, http.StatusOK)
|
||||
}
|
||||
|
||||
// setupContext handles some of the administrative work around receiving a new request, such as checking the result ID, the campaign, etc.
|
||||
// setupContext handles some of the administrative work around receiving a new
|
||||
// request, such as checking the result ID, the campaign, etc.
|
||||
func setupContext(r *http.Request) (error, *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
|
|
|
@ -38,7 +38,7 @@ func (s *ControllersSuite) getFirstEmailRequest() models.EmailRequest {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) openEmail(rid string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
|
@ -49,19 +49,19 @@ func (s *ControllersSuite) openEmail(rid string) {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) reportedEmail(rid string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) reportEmail404(rid string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/report?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) openEmail404(rid string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Nil(err)
|
||||
|
@ -69,7 +69,7 @@ func (s *ControllersSuite) openEmail404(rid string) {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
|
@ -79,7 +79,7 @@ func (s *ControllersSuite) clickLink(rid string, expectedHTML string) {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) clickLink404(rid string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Nil(err)
|
||||
|
@ -87,14 +87,14 @@ func (s *ControllersSuite) clickLink404(rid string) {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) transparencyRequest(r models.Result, rid, path string) {
|
||||
resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", ps.URL, path, models.RecipientParameter, rid))
|
||||
resp, err := http.Get(fmt.Sprintf("%s%s?%s=%s", s.phishServer.URL, path, models.RecipientParameter, rid))
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Equal(resp.StatusCode, http.StatusOK)
|
||||
tr := &TransparencyResponse{}
|
||||
err = json.NewDecoder(resp.Body).Decode(tr)
|
||||
s.Nil(err)
|
||||
s.Equal(tr.ContactAddress, config.Conf.ContactAddress)
|
||||
s.Equal(tr.ContactAddress, s.config.ContactAddress)
|
||||
s.Equal(tr.SendDate, r.SendDate)
|
||||
s.Equal(tr.Server, config.ServerName)
|
||||
}
|
||||
|
@ -146,11 +146,11 @@ func (s *ControllersSuite) TestClickedPhishingLinkAfterOpen() {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) TestNoRecipientID() {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track", ps.URL))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/track", s.phishServer.URL))
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusNotFound)
|
||||
|
||||
resp, err = http.Get(ps.URL)
|
||||
resp, err = http.Get(s.phishServer.URL)
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ func (s *ControllersSuite) TestCompletedCampaignClick() {
|
|||
|
||||
func (s *ControllersSuite) TestRobotsHandler() {
|
||||
expected := []byte("User-agent: *\nDisallow: /\n")
|
||||
resp, err := http.Get(fmt.Sprintf("%s/robots.txt", ps.URL))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/robots.txt", s.phishServer.URL))
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusOK)
|
||||
defer resp.Body.Close()
|
||||
|
@ -259,7 +259,7 @@ func (s *ControllersSuite) TestRedirectTemplating() {
|
|||
},
|
||||
}
|
||||
result := campaign.Results[0]
|
||||
resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", ps.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
|
||||
resp, err := client.PostForm(fmt.Sprintf("%s/?%s=%s", s.phishServer.URL, models.RecipientParameter, result.RId), url.Values{"username": {"test"}, "password": {"test"}})
|
||||
s.Nil(err)
|
||||
defer resp.Body.Close()
|
||||
s.Equal(http.StatusFound, resp.StatusCode)
|
||||
|
|
|
@ -1,72 +1,153 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gophish/gophish/auth"
|
||||
"github.com/gophish/gophish/config"
|
||||
ctx "github.com/gophish/gophish/context"
|
||||
log "github.com/gophish/gophish/logger"
|
||||
mid "github.com/gophish/gophish/middleware"
|
||||
"github.com/gophish/gophish/models"
|
||||
"github.com/gophish/gophish/util"
|
||||
"github.com/gophish/gophish/worker"
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/gorilla/handlers"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jordan-wright/unindexed"
|
||||
)
|
||||
|
||||
// CreateAdminRouter creates the routes for handling requests to the web interface.
|
||||
// AdminServerOption is a functional option that is used to configure the
|
||||
// admin server
|
||||
type AdminServerOption func(*AdminServer)
|
||||
|
||||
// AdminServer is an HTTP server that implements the administrative Gophish
|
||||
// handlers, including the dashboard and REST API.
|
||||
type AdminServer struct {
|
||||
server *http.Server
|
||||
worker worker.Worker
|
||||
config config.AdminServer
|
||||
}
|
||||
|
||||
// WithWorker is an option that sets the background worker.
|
||||
func WithWorker(w worker.Worker) AdminServerOption {
|
||||
return func(as *AdminServer) {
|
||||
as.worker = w
|
||||
}
|
||||
}
|
||||
|
||||
// NewAdminServer returns a new instance of the AdminServer with the
|
||||
// provided config and options applied.
|
||||
func NewAdminServer(config config.AdminServer, options ...AdminServerOption) *AdminServer {
|
||||
defaultWorker, _ := worker.New()
|
||||
defaultServer := &http.Server{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
Addr: config.ListenURL,
|
||||
}
|
||||
as := &AdminServer{
|
||||
worker: defaultWorker,
|
||||
server: defaultServer,
|
||||
config: config,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(as)
|
||||
}
|
||||
as.registerRoutes()
|
||||
return as
|
||||
}
|
||||
|
||||
// Start launches the admin server, listening on the configured address.
|
||||
func (as *AdminServer) Start() error {
|
||||
if as.worker != nil {
|
||||
go as.worker.Start()
|
||||
}
|
||||
if as.config.UseTLS {
|
||||
err := util.CheckAndCreateSSL(as.config.CertPath, as.config.KeyPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return err
|
||||
}
|
||||
log.Infof("Starting admin server at https://%s", as.config.ListenURL)
|
||||
return as.server.ListenAndServeTLS(as.config.CertPath, as.config.KeyPath)
|
||||
}
|
||||
// If TLS isn't configured, just listen on HTTP
|
||||
log.Infof("Starting admin server at http://%s", as.config.ListenURL)
|
||||
return as.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown attempts to gracefully shutdown the server.
|
||||
func (as *AdminServer) Shutdown() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
return as.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// SetupAdminRoutes creates the routes for handling requests to the web interface.
|
||||
// This function returns an http.Handler to be used in http.ListenAndServe().
|
||||
func CreateAdminRouter() http.Handler {
|
||||
func (as *AdminServer) registerRoutes() {
|
||||
router := mux.NewRouter()
|
||||
// Base Front-end routes
|
||||
router.HandleFunc("/", Use(Base, mid.RequireLogin))
|
||||
router.HandleFunc("/login", Login)
|
||||
router.HandleFunc("/logout", Use(Logout, mid.RequireLogin))
|
||||
router.HandleFunc("/campaigns", Use(Campaigns, mid.RequireLogin))
|
||||
router.HandleFunc("/campaigns/{id:[0-9]+}", Use(CampaignID, mid.RequireLogin))
|
||||
router.HandleFunc("/templates", Use(Templates, mid.RequireLogin))
|
||||
router.HandleFunc("/users", Use(Users, mid.RequireLogin))
|
||||
router.HandleFunc("/landing_pages", Use(LandingPages, mid.RequireLogin))
|
||||
router.HandleFunc("/sending_profiles", Use(SendingProfiles, mid.RequireLogin))
|
||||
router.HandleFunc("/register", Use(Register, mid.RequireLogin))
|
||||
router.HandleFunc("/settings", Use(Settings, mid.RequireLogin))
|
||||
router.HandleFunc("/", Use(as.Base, mid.RequireLogin))
|
||||
router.HandleFunc("/login", as.Login)
|
||||
router.HandleFunc("/logout", Use(as.Logout, mid.RequireLogin))
|
||||
router.HandleFunc("/campaigns", Use(as.Campaigns, mid.RequireLogin))
|
||||
router.HandleFunc("/campaigns/{id:[0-9]+}", Use(as.CampaignID, mid.RequireLogin))
|
||||
router.HandleFunc("/templates", Use(as.Templates, mid.RequireLogin))
|
||||
router.HandleFunc("/users", Use(as.Users, mid.RequireLogin))
|
||||
router.HandleFunc("/landing_pages", Use(as.LandingPages, mid.RequireLogin))
|
||||
router.HandleFunc("/sending_profiles", Use(as.SendingProfiles, mid.RequireLogin))
|
||||
router.HandleFunc("/register", Use(as.Register, mid.RequireLogin))
|
||||
router.HandleFunc("/settings", Use(as.Settings, mid.RequireLogin))
|
||||
// Create the API routes
|
||||
api := router.PathPrefix("/api").Subrouter()
|
||||
api = api.StrictSlash(true)
|
||||
api.HandleFunc("/reset", Use(API_Reset, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/", Use(API_Campaigns, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/summary", Use(API_Campaigns_Summary, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}", Use(API_Campaigns_Id, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/results", Use(API_Campaigns_Id_Results, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/summary", Use(API_Campaign_Id_Summary, mid.RequireAPIKey))
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/complete", Use(API_Campaigns_Id_Complete, mid.RequireAPIKey))
|
||||
api.HandleFunc("/groups/", Use(API_Groups, mid.RequireAPIKey))
|
||||
api.HandleFunc("/groups/summary", Use(API_Groups_Summary, mid.RequireAPIKey))
|
||||
api.HandleFunc("/groups/{id:[0-9]+}", Use(API_Groups_Id, mid.RequireAPIKey))
|
||||
api.HandleFunc("/groups/{id:[0-9]+}/summary", Use(API_Groups_Id_Summary, mid.RequireAPIKey))
|
||||
api.HandleFunc("/templates/", Use(API_Templates, mid.RequireAPIKey))
|
||||
api.HandleFunc("/templates/{id:[0-9]+}", Use(API_Templates_Id, mid.RequireAPIKey))
|
||||
api.HandleFunc("/pages/", Use(API_Pages, mid.RequireAPIKey))
|
||||
api.HandleFunc("/pages/{id:[0-9]+}", Use(API_Pages_Id, mid.RequireAPIKey))
|
||||
api.HandleFunc("/smtp/", Use(API_SMTP, mid.RequireAPIKey))
|
||||
api.HandleFunc("/smtp/{id:[0-9]+}", Use(API_SMTP_Id, mid.RequireAPIKey))
|
||||
api.HandleFunc("/util/send_test_email", Use(API_Send_Test_Email, mid.RequireAPIKey))
|
||||
api.HandleFunc("/import/group", Use(API_Import_Group, mid.RequireAPIKey))
|
||||
api.HandleFunc("/import/email", Use(API_Import_Email, mid.RequireAPIKey))
|
||||
api.HandleFunc("/import/site", Use(API_Import_Site, mid.RequireAPIKey))
|
||||
api.Use(mid.RequireAPIKey)
|
||||
api.HandleFunc("/reset", as.API_Reset)
|
||||
api.HandleFunc("/campaigns/", as.API_Campaigns)
|
||||
api.HandleFunc("/campaigns/summary", as.API_Campaigns_Summary)
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}", as.API_Campaigns_Id)
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/results", as.API_Campaigns_Id_Results)
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/summary", as.API_Campaign_Id_Summary)
|
||||
api.HandleFunc("/campaigns/{id:[0-9]+}/complete", as.API_Campaigns_Id_Complete)
|
||||
api.HandleFunc("/groups/", as.API_Groups)
|
||||
api.HandleFunc("/groups/summary", as.API_Groups_Summary)
|
||||
api.HandleFunc("/groups/{id:[0-9]+}", as.API_Groups_Id)
|
||||
api.HandleFunc("/groups/{id:[0-9]+}/summary", as.API_Groups_Id_Summary)
|
||||
api.HandleFunc("/templates/", as.API_Templates)
|
||||
api.HandleFunc("/templates/{id:[0-9]+}", as.API_Templates_Id)
|
||||
api.HandleFunc("/pages/", as.API_Pages)
|
||||
api.HandleFunc("/pages/{id:[0-9]+}", as.API_Pages_Id)
|
||||
api.HandleFunc("/smtp/", as.API_SMTP)
|
||||
api.HandleFunc("/smtp/{id:[0-9]+}", as.API_SMTP_Id)
|
||||
api.HandleFunc("/util/send_test_email", as.API_Send_Test_Email)
|
||||
api.HandleFunc("/import/group", as.API_Import_Group)
|
||||
api.HandleFunc("/import/email", as.API_Import_Email)
|
||||
api.HandleFunc("/import/site", as.API_Import_Site)
|
||||
|
||||
// Setup static file serving
|
||||
router.PathPrefix("/").Handler(http.FileServer(UnindexedFileSystem{http.Dir("./static/")}))
|
||||
router.PathPrefix("/").Handler(http.FileServer(unindexed.Dir("./static/")))
|
||||
|
||||
// Setup CSRF Protection
|
||||
csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()),
|
||||
csrf.FieldName("csrf_token"),
|
||||
csrf.Secure(config.Conf.AdminConf.UseTLS))
|
||||
csrfRouter := csrfHandler(router)
|
||||
return Use(csrfRouter.ServeHTTP, mid.CSRFExceptions, mid.GetContext)
|
||||
csrf.Secure(as.config.UseTLS))
|
||||
adminHandler := csrfHandler(router)
|
||||
adminHandler = Use(adminHandler.ServeHTTP, mid.CSRFExceptions, mid.GetContext)
|
||||
|
||||
// Setup GZIP compression
|
||||
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
|
||||
adminHandler = gzipWrapper(adminHandler)
|
||||
|
||||
// Setup logging
|
||||
adminHandler = handlers.CombinedLoggingHandler(log.Writer(), adminHandler)
|
||||
as.server.Handler = adminHandler
|
||||
}
|
||||
|
||||
// Use allows us to stack middleware to process the request
|
||||
|
@ -78,16 +159,29 @@ func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) h
|
|||
return handler
|
||||
}
|
||||
|
||||
type templateParams struct {
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
User models.User
|
||||
Token string
|
||||
Version string
|
||||
}
|
||||
|
||||
// newTemplateParams returns the default template parameters for a user and
|
||||
// the CSRF token.
|
||||
func newTemplateParams(r *http.Request) templateParams {
|
||||
return templateParams{
|
||||
Token: csrf.Token(r),
|
||||
User: ctx.Get(r, "user").(models.User),
|
||||
Version: config.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// Register creates a new user
|
||||
func Register(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) Register(w http.ResponseWriter, r *http.Request) {
|
||||
// If it is a post request, attempt to register the account
|
||||
// Now that we are all registered, we can log the user in
|
||||
params := struct {
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
User models.User
|
||||
Token string
|
||||
}{Title: "Register", Token: csrf.Token(r)}
|
||||
params := templateParams{Title: "Register", Token: csrf.Token(r)}
|
||||
session := ctx.Get(r, "session").(*sessions.Session)
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
|
@ -120,99 +214,60 @@ func Register(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Base handles the default path and template execution
|
||||
func Base(w http.ResponseWriter, r *http.Request) {
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Dashboard", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) Base(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Dashboard"
|
||||
getTemplate(w, "dashboard").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// Campaigns handles the default path and template execution
|
||||
func Campaigns(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Campaigns", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) Campaigns(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Campaigns"
|
||||
getTemplate(w, "campaigns").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// CampaignID handles the default path and template execution
|
||||
func CampaignID(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Campaign Results", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) CampaignID(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Campaign Results"
|
||||
getTemplate(w, "campaign_results").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// Templates handles the default path and template execution
|
||||
func Templates(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Email Templates", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) Templates(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Email Templates"
|
||||
getTemplate(w, "templates").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// Users handles the default path and template execution
|
||||
func Users(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Users & Groups", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) Users(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Users & Groups"
|
||||
getTemplate(w, "users").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// LandingPages handles the default path and template execution
|
||||
func LandingPages(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Landing Pages", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) LandingPages(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Landing Pages"
|
||||
getTemplate(w, "landing_pages").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// SendingProfiles handles the default path and template execution
|
||||
func SendingProfiles(w http.ResponseWriter, r *http.Request) {
|
||||
// Example of using session - will be removed.
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
}{Title: "Sending Profiles", User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
func (as *AdminServer) SendingProfiles(w http.ResponseWriter, r *http.Request) {
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Sending Profiles"
|
||||
getTemplate(w, "sending_profiles").ExecuteTemplate(w, "base", params)
|
||||
}
|
||||
|
||||
// Settings handles the changing of settings
|
||||
func Settings(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) Settings(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
Flashes []interface{}
|
||||
Token string
|
||||
Version string
|
||||
}{Title: "Settings", Version: config.Version, User: ctx.Get(r, "user").(models.User), Token: csrf.Token(r)}
|
||||
params := newTemplateParams(r)
|
||||
params.Title = "Settings"
|
||||
getTemplate(w, "settings").ExecuteTemplate(w, "base", params)
|
||||
case r.Method == "POST":
|
||||
err := auth.ChangePassword(r)
|
||||
|
@ -235,7 +290,7 @@ func Settings(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// Login handles the authentication flow for a user. If credentials are valid,
|
||||
// a session is created
|
||||
func Login(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) Login(w http.ResponseWriter, r *http.Request) {
|
||||
params := struct {
|
||||
User models.User
|
||||
Title string
|
||||
|
@ -289,7 +344,7 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Logout destroys the current user session
|
||||
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||
func (as *AdminServer) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
session := ctx.Get(r, "session").(*sessions.Session)
|
||||
delete(session.Values, "id")
|
||||
Flash(w, r, "success", "You have successfully logged out")
|
||||
|
@ -297,28 +352,6 @@ func Logout(w http.ResponseWriter, r *http.Request) {
|
|||
http.Redirect(w, r, "/login", 302)
|
||||
}
|
||||
|
||||
// Preview allows for the viewing of page html in a separate browser window
|
||||
func Preview(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s", r.FormValue("html"))
|
||||
}
|
||||
|
||||
// Clone takes a URL as a POST parameter and returns the site HTML
|
||||
func Clone(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if url, ok := vars["url"]; ok {
|
||||
log.Error(url)
|
||||
}
|
||||
http.Error(w, "No URL given.", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func getTemplate(w http.ResponseWriter, tmpl string) *template.Template {
|
||||
templates := template.New("template")
|
||||
_, err := templates.ParseFiles("templates/base.html", "templates/"+tmpl+".html", "templates/flashes.html")
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func (s *ControllersSuite) TestLoginCSRF() {
|
||||
resp, err := http.PostForm(fmt.Sprintf("%s/login", as.URL),
|
||||
resp, err := http.PostForm(fmt.Sprintf("%s/login", s.adminServer.URL),
|
||||
url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"gophish"},
|
||||
|
@ -21,7 +21,7 @@ func (s *ControllersSuite) TestLoginCSRF() {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) TestInvalidCredentials() {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
|
||||
s.Equal(err, nil)
|
||||
s.Equal(resp.StatusCode, http.StatusOK)
|
||||
|
||||
|
@ -32,7 +32,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
|
|||
s.Equal(ok, true)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"invalid"},
|
||||
"csrf_token": {token},
|
||||
|
@ -48,7 +48,7 @@ func (s *ControllersSuite) TestInvalidCredentials() {
|
|||
}
|
||||
|
||||
func (s *ControllersSuite) TestSuccessfulLogin() {
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
|
||||
s.Equal(err, nil)
|
||||
s.Equal(resp.StatusCode, http.StatusOK)
|
||||
|
||||
|
@ -59,7 +59,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
|
|||
s.Equal(ok, true)
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", as.URL), strings.NewReader(url.Values{
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login", s.adminServer.URL), strings.NewReader(url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"gophish"},
|
||||
"csrf_token": {token},
|
||||
|
@ -76,7 +76,7 @@ func (s *ControllersSuite) TestSuccessfulLogin() {
|
|||
|
||||
func (s *ControllersSuite) TestSuccessfulRedirect() {
|
||||
next := "/campaigns"
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", as.URL))
|
||||
resp, err := http.Get(fmt.Sprintf("%s/login", s.adminServer.URL))
|
||||
s.Equal(err, nil)
|
||||
s.Equal(resp.StatusCode, http.StatusOK)
|
||||
|
||||
|
@ -91,7 +91,7 @@ func (s *ControllersSuite) TestSuccessfulRedirect() {
|
|||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", as.URL, next), strings.NewReader(url.Values{
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/login?next=%s", s.adminServer.URL, next), strings.NewReader(url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"gophish"},
|
||||
"csrf_token": {token},
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UnindexedFileSystem is an implementation of a standard http.FileSystem
|
||||
// without the ability to list files in the directory.
|
||||
// This implementation is largely inspired by
|
||||
// https://www.alexedwards.net/blog/disable-http-fileserver-directory-listings
|
||||
type UnindexedFileSystem struct {
|
||||
fs http.FileSystem
|
||||
}
|
||||
|
||||
// Open returns a file from the static directory. If the requested path ends
|
||||
// with a slash, there is a check for an index.html file. If none exists, then
|
||||
// an error is returned.
|
||||
func (ufs UnindexedFileSystem) Open(name string) (http.File, error) {
|
||||
f, err := ufs.fs.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := f.Stat()
|
||||
if s.IsDir() {
|
||||
index := strings.TrimSuffix(name, "/") + "/index.html"
|
||||
indexFile, err := ufs.fs.Open(index)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return indexFile, nil
|
||||
}
|
||||
return f, nil
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package controllers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var fileContent = []byte("Hello world")
|
||||
|
||||
func mustRemoveAll(dir string) {
|
||||
err := os.RemoveAll(dir)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestFile(dir, filename string) error {
|
||||
return ioutil.WriteFile(filepath.Join(dir, filename), fileContent, 0644)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) TestGetStaticFile() {
|
||||
dir, err := ioutil.TempDir("static/endpoint", "test-")
|
||||
tempFolder := filepath.Base(dir)
|
||||
|
||||
s.Nil(err)
|
||||
defer mustRemoveAll(dir)
|
||||
|
||||
err = createTestFile(dir, "foo.txt")
|
||||
s.Nil(nil, err)
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("%s/static/%s/foo.txt", ps.URL, tempFolder))
|
||||
s.Nil(err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
got, err := ioutil.ReadAll(resp.Body)
|
||||
s.Nil(err)
|
||||
|
||||
s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) TestStaticFileListing() {
|
||||
dir, err := ioutil.TempDir("static/endpoint", "test-")
|
||||
tempFolder := filepath.Base(dir)
|
||||
|
||||
s.Nil(err)
|
||||
defer mustRemoveAll(dir)
|
||||
|
||||
err = createTestFile(dir, "foo.txt")
|
||||
s.Nil(nil, err)
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
|
||||
s.Nil(err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
s.Nil(err)
|
||||
s.Equal(resp.StatusCode, http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (s *ControllersSuite) TestStaticIndex() {
|
||||
dir, err := ioutil.TempDir("static/endpoint", "test-")
|
||||
tempFolder := filepath.Base(dir)
|
||||
|
||||
s.Nil(err)
|
||||
defer mustRemoveAll(dir)
|
||||
|
||||
err = createTestFile(dir, "index.html")
|
||||
s.Nil(nil, err)
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder))
|
||||
s.Nil(err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
got, err := ioutil.ReadAll(resp.Body)
|
||||
s.Nil(err)
|
||||
|
||||
s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got))
|
||||
}
|
81
gophish.go
81
gophish.go
|
@ -26,24 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|||
THE SOFTWARE.
|
||||
*/
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"os/signal"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/gophish/gophish/auth"
|
||||
"github.com/gophish/gophish/config"
|
||||
"github.com/gophish/gophish/controllers"
|
||||
log "github.com/gophish/gophish/logger"
|
||||
"github.com/gophish/gophish/mailer"
|
||||
"github.com/gophish/gophish/models"
|
||||
"github.com/gophish/gophish/util"
|
||||
"github.com/gorilla/handlers"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -65,31 +58,25 @@ func main() {
|
|||
kingpin.Parse()
|
||||
|
||||
// Load the config
|
||||
err = config.LoadConfig(*configPath)
|
||||
conf, err := config.LoadConfig(*configPath)
|
||||
// Just warn if a contact address hasn't been configured
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if config.Conf.ContactAddress == "" {
|
||||
if conf.ContactAddress == "" {
|
||||
log.Warnf("No contact address has been configured.")
|
||||
log.Warnf("Please consider adding a contact_address entry in your config.json")
|
||||
}
|
||||
config.Version = string(version)
|
||||
|
||||
err = log.Setup()
|
||||
err = log.Setup(conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Provide the option to disable the built-in mailer
|
||||
if !*disableMailer {
|
||||
go mailer.Mailer.Start(ctx)
|
||||
}
|
||||
// Setup the global variables and settings
|
||||
err = models.Setup()
|
||||
err = models.Setup(conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -99,39 +86,27 @@ func main() {
|
|||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
// Start the web servers
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
gzipWrapper, _ := gziphandler.NewGzipLevelHandler(gzip.BestCompression)
|
||||
adminHandler := gzipWrapper(controllers.CreateAdminRouter())
|
||||
auth.Store.Options.Secure = config.Conf.AdminConf.UseTLS
|
||||
if config.Conf.AdminConf.UseTLS { // use TLS for Admin web server if available
|
||||
err := util.CheckAndCreateSSL(config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Infof("Starting admin server at https://%s", config.Conf.AdminConf.ListenURL)
|
||||
log.Info(http.ListenAndServeTLS(config.Conf.AdminConf.ListenURL, config.Conf.AdminConf.CertPath, config.Conf.AdminConf.KeyPath,
|
||||
handlers.CombinedLoggingHandler(log.Writer(), adminHandler)))
|
||||
} else {
|
||||
log.Infof("Starting admin server at http://%s", config.Conf.AdminConf.ListenURL)
|
||||
log.Info(http.ListenAndServe(config.Conf.AdminConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, adminHandler)))
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
phishHandler := gziphandler.GzipHandler(controllers.CreatePhishingRouter())
|
||||
if config.Conf.PhishConf.UseTLS { // use TLS for Phish web server if available
|
||||
log.Infof("Starting phishing server at https://%s", config.Conf.PhishConf.ListenURL)
|
||||
log.Info(http.ListenAndServeTLS(config.Conf.PhishConf.ListenURL, config.Conf.PhishConf.CertPath, config.Conf.PhishConf.KeyPath,
|
||||
handlers.CombinedLoggingHandler(log.Writer(), phishHandler)))
|
||||
} else {
|
||||
log.Infof("Starting phishing server at http://%s", config.Conf.PhishConf.ListenURL)
|
||||
log.Fatal(http.ListenAndServe(config.Conf.PhishConf.ListenURL, handlers.CombinedLoggingHandler(os.Stdout, phishHandler)))
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
// Create our servers
|
||||
adminOptions := []controllers.AdminServerOption{}
|
||||
if *disableMailer {
|
||||
adminOptions = append(adminOptions, controllers.WithWorker(nil))
|
||||
}
|
||||
adminConfig := conf.AdminConf
|
||||
adminServer := controllers.NewAdminServer(adminConfig, adminOptions...)
|
||||
auth.Store.Options.Secure = adminConfig.UseTLS
|
||||
|
||||
phishConfig := conf.PhishConf
|
||||
phishServer := controllers.NewPhishingServer(phishConfig)
|
||||
|
||||
go adminServer.Start()
|
||||
go phishServer.Start()
|
||||
|
||||
// Handle graceful shutdown
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
<-c
|
||||
log.Info("CTRL+C Received... Gracefully shutting down servers")
|
||||
adminServer.Shutdown()
|
||||
phishServer.Shutdown()
|
||||
}
|
||||
|
|
|
@ -18,10 +18,10 @@ func init() {
|
|||
}
|
||||
|
||||
// Setup configures the logger based on options in the config.json.
|
||||
func Setup() error {
|
||||
func Setup(conf *config.Config) error {
|
||||
Logger.SetLevel(logrus.InfoLevel)
|
||||
// Set up logging to a file if specified in the config
|
||||
logFile := config.Conf.Logging.Filename
|
||||
logFile := conf.Logging.Filename
|
||||
if logFile != "" {
|
||||
f, err := os.OpenFile(logFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
|
||||
if err != nil {
|
||||
|
|
|
@ -29,6 +29,13 @@ func (e *ErrMaxConnectAttempts) Error() string {
|
|||
return errString
|
||||
}
|
||||
|
||||
// Mailer is an interface that defines an object used to queue and
|
||||
// send mailer.Mail instances.
|
||||
type Mailer interface {
|
||||
Start(ctx context.Context)
|
||||
Queue([]Mail)
|
||||
}
|
||||
|
||||
// Sender exposes the common operations required for sending email.
|
||||
type Sender interface {
|
||||
Send(from string, to []string, msg io.WriterTo) error
|
||||
|
@ -50,27 +57,18 @@ type Mail interface {
|
|||
GetDialer() (Dialer, error)
|
||||
}
|
||||
|
||||
// Mailer is a global instance of the mailer that can
|
||||
// be used in applications. It is the responsibility of the application
|
||||
// to call Mailer.Start()
|
||||
var Mailer *MailWorker
|
||||
|
||||
func init() {
|
||||
Mailer = NewMailWorker()
|
||||
}
|
||||
|
||||
// MailWorker is the worker that receives slices of emails
|
||||
// on a channel to send. It's assumed that every slice of emails received is meant
|
||||
// to be sent to the same server.
|
||||
type MailWorker struct {
|
||||
Queue chan []Mail
|
||||
queue chan []Mail
|
||||
}
|
||||
|
||||
// NewMailWorker returns an instance of MailWorker with the mail queue
|
||||
// initialized.
|
||||
func NewMailWorker() *MailWorker {
|
||||
return &MailWorker{
|
||||
Queue: make(chan []Mail),
|
||||
queue: make(chan []Mail),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,7 +79,7 @@ func (mw *MailWorker) Start(ctx context.Context) {
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case ms := <-mw.Queue:
|
||||
case ms := <-mw.queue:
|
||||
go func(ctx context.Context, ms []Mail) {
|
||||
dialer, err := ms[0].GetDialer()
|
||||
if err != nil {
|
||||
|
@ -94,6 +92,11 @@ func (mw *MailWorker) Start(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
// Queue sends the provided mail to the internal queue for processing.
|
||||
func (mw *MailWorker) Queue(ms []Mail) {
|
||||
mw.queue <- ms
|
||||
}
|
||||
|
||||
// errorMail is a helper to handle erroring out a slice of Mail instances
|
||||
// in the case that an unrecoverable error occurs.
|
||||
func errorMail(err error, ms []Mail) {
|
||||
|
|
|
@ -88,7 +88,7 @@ func (ms *MailerSuite) TestMailWorkerStart() {
|
|||
messages := generateMessages(dialer)
|
||||
|
||||
// Send the campaign
|
||||
mw.Queue <- messages
|
||||
mw.Queue(messages)
|
||||
|
||||
got := []*mockMessage{}
|
||||
|
||||
|
@ -129,7 +129,7 @@ func (ms *MailerSuite) TestBackoff() {
|
|||
messages := generateMessages(dialer)
|
||||
|
||||
// Send the campaign
|
||||
mw.Queue <- messages
|
||||
mw.Queue(messages)
|
||||
|
||||
got := []*mockMessage{}
|
||||
|
||||
|
@ -183,7 +183,7 @@ func (ms *MailerSuite) TestPermError() {
|
|||
messages := generateMessages(dialer)
|
||||
|
||||
// Send the campaign
|
||||
mw.Queue <- messages
|
||||
mw.Queue(messages)
|
||||
|
||||
got := []*mockMessage{}
|
||||
|
||||
|
@ -242,7 +242,7 @@ func (ms *MailerSuite) TestUnknownError() {
|
|||
messages := generateMessages(dialer)
|
||||
|
||||
// Send the campaign
|
||||
mw.Queue <- messages
|
||||
mw.Queue(messages)
|
||||
|
||||
got := []*mockMessage{}
|
||||
|
||||
|
|
|
@ -60,8 +60,10 @@ func GetContext(handler http.Handler) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func RequireAPIKey(handler http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// RequireAPIKey ensures that a valid API key is set as either the api_key GET
|
||||
// parameter, or a Bearer token.
|
||||
func RequireAPIKey(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
if r.Method == "OPTIONS" {
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
||||
|
@ -81,18 +83,18 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
if ak == "" {
|
||||
JSONError(w, 400, "API Key not set")
|
||||
JSONError(w, http.StatusUnauthorized, "API Key not set")
|
||||
return
|
||||
}
|
||||
u, err := models.GetUserByAPIKey(ak)
|
||||
if err != nil {
|
||||
JSONError(w, 400, "Invalid API Key")
|
||||
JSONError(w, http.StatusUnauthorized, "Invalid API Key")
|
||||
return
|
||||
}
|
||||
r = ctx.Set(r, "user_id", u.Id)
|
||||
r = ctx.Set(r, "api_key", ak)
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// RequireLogin is a simple middleware which checks to see if the user is currently logged in.
|
||||
|
@ -104,7 +106,7 @@ func RequireLogin(handler http.Handler) http.HandlerFunc {
|
|||
} else {
|
||||
q := r.URL.Query()
|
||||
q.Set("next", r.URL.Path)
|
||||
http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), 302)
|
||||
http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -165,7 +165,7 @@ func (c *Campaign) AddEvent(e *Event) error {
|
|||
// an error is returned. Otherwise, the attribute name is set to [Deleted],
|
||||
// indicating the user deleted the attribute (template, smtp, etc.)
|
||||
func (c *Campaign) getDetails() error {
|
||||
err = db.Model(c).Related(&c.Results).Error
|
||||
err := db.Model(c).Related(&c.Results).Error
|
||||
if err != nil {
|
||||
log.Warnf("%s: results not found for campaign", err)
|
||||
return err
|
||||
|
@ -402,7 +402,8 @@ func GetQueuedCampaigns(t time.Time) ([]Campaign, error) {
|
|||
|
||||
// PostCampaign inserts a campaign and all associated records into the database.
|
||||
func PostCampaign(c *Campaign, uid int64) error {
|
||||
if err := c.Validate(); err != nil {
|
||||
err := c.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Fill in the details
|
||||
|
@ -514,14 +515,16 @@ func PostCampaign(c *Campaign, uid int64) error {
|
|||
Reported: false,
|
||||
ModifiedDate: c.CreatedDate,
|
||||
}
|
||||
if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
|
||||
r.Status = STATUS_SENDING
|
||||
}
|
||||
err = r.GenerateId()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
continue
|
||||
}
|
||||
processing := false
|
||||
if r.SendDate.Before(c.CreatedDate) || r.SendDate.Equal(c.CreatedDate) {
|
||||
r.Status = STATUS_SENDING
|
||||
processing = true
|
||||
}
|
||||
err = db.Save(r).Error
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
|
@ -530,7 +533,14 @@ func PostCampaign(c *Campaign, uid int64) error {
|
|||
}
|
||||
c.Results = append(c.Results, *r)
|
||||
log.Infof("Creating maillog for %s to send at %s\n", r.Email, sendDate)
|
||||
err = GenerateMailLog(c, r, sendDate)
|
||||
m := &MailLog{
|
||||
UserId: c.UserId,
|
||||
CampaignId: c.Id,
|
||||
RId: r.RId,
|
||||
SendDate: sendDate,
|
||||
Processing: processing,
|
||||
}
|
||||
err = db.Save(m).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
continue
|
||||
|
|
|
@ -79,3 +79,29 @@ func (s *ModelsSuite) TestCampaignDateValidation(c *check.C) {
|
|||
err = campaign.Validate()
|
||||
c.Assert(err, check.Equals, ErrInvalidSendByDate)
|
||||
}
|
||||
|
||||
func (s *ModelsSuite) TestLaunchCampaignMaillogStatus(c *check.C) {
|
||||
// For the first test, ensure that campaigns created with the zero date
|
||||
// (and therefore are set to launch immediately) have maillogs that are
|
||||
// locked to prevent race conditions.
|
||||
campaign := s.createCampaign(c)
|
||||
ms, err := GetMailLogsByCampaign(campaign.Id)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
|
||||
for _, m := range ms {
|
||||
c.Assert(m.Processing, check.Equals, true)
|
||||
}
|
||||
|
||||
// Next, verify that campaigns scheduled in the future do not lock the
|
||||
// maillogs so that they can be picked up by the background worker.
|
||||
campaign = s.createCampaignDependencies(c)
|
||||
campaign.Name = "New Campaign"
|
||||
campaign.LaunchDate = time.Now().Add(1 * time.Hour)
|
||||
c.Assert(PostCampaign(&campaign, campaign.UserId), check.Equals, nil)
|
||||
ms, err = GetMailLogsByCampaign(campaign.Id)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
|
||||
for _, m := range ms {
|
||||
c.Assert(m.Processing, check.Equals, false)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -122,8 +122,8 @@ func (s *EmailRequest) Generate(msg *gomail.Message) error {
|
|||
|
||||
// Add the transparency headers
|
||||
msg.SetHeader("X-Mailer", config.ServerName)
|
||||
if config.Conf.ContactAddress != "" {
|
||||
msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress)
|
||||
if conf.ContactAddress != "" {
|
||||
msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
|
||||
}
|
||||
|
||||
// Parse the customHeader templates
|
||||
|
|
|
@ -26,7 +26,7 @@ func (s *ModelsSuite) TestEmailRequestBackoff(ch *check.C) {
|
|||
}
|
||||
expected := errors.New("Temporary Error")
|
||||
go func() {
|
||||
err = req.Backoff(expected)
|
||||
err := req.Backoff(expected)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
}()
|
||||
ch.Assert(<-req.ErrorChan, check.Equals, expected)
|
||||
|
@ -38,7 +38,7 @@ func (s *ModelsSuite) TestEmailRequestError(ch *check.C) {
|
|||
}
|
||||
expected := errors.New("Temporary Error")
|
||||
go func() {
|
||||
err = req.Error(expected)
|
||||
err := req.Error(expected)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
}()
|
||||
ch.Assert(<-req.ErrorChan, check.Equals, expected)
|
||||
|
@ -49,7 +49,7 @@ func (s *ModelsSuite) TestEmailRequestSuccess(ch *check.C) {
|
|||
ErrorChan: make(chan error),
|
||||
}
|
||||
go func() {
|
||||
err = req.Success()
|
||||
err := req.Success()
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
}()
|
||||
ch.Assert(<-req.ErrorChan, check.Equals, nil)
|
||||
|
@ -76,14 +76,14 @@ func (s *ModelsSuite) TestEmailRequestGenerate(ch *check.C) {
|
|||
FromAddress: smtp.FromAddress,
|
||||
}
|
||||
|
||||
config.Conf.ContactAddress = "test@test.com"
|
||||
s.config.ContactAddress = "test@test.com"
|
||||
expectedHeaders := map[string]string{
|
||||
"X-Mailer": config.ServerName,
|
||||
"X-Gophish-Contact": config.Conf.ContactAddress,
|
||||
"X-Gophish-Contact": s.config.ContactAddress,
|
||||
}
|
||||
|
||||
msg := gomail.NewMessage()
|
||||
err = req.Generate(msg)
|
||||
err := req.Generate(msg)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
|
||||
expected := &email.Email{
|
||||
|
@ -130,7 +130,7 @@ func (s *ModelsSuite) TestEmailRequestURLTemplating(ch *check.C) {
|
|||
}
|
||||
|
||||
msg := gomail.NewMessage()
|
||||
err = req.Generate(msg)
|
||||
err := req.Generate(msg)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
|
||||
expectedURL := fmt.Sprintf("http://127.0.0.1/%s?%s=%s", req.Email, RecipientParameter, req.RId)
|
||||
|
@ -167,7 +167,7 @@ func (s *ModelsSuite) TestEmailRequestGenerateEmptySubject(ch *check.C) {
|
|||
}
|
||||
|
||||
msg := gomail.NewMessage()
|
||||
err = req.Generate(msg)
|
||||
err := req.Generate(msg)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
|
||||
expected := &email.Email{
|
||||
|
|
|
@ -196,7 +196,7 @@ func PostGroup(g *Group) error {
|
|||
return err
|
||||
}
|
||||
// Insert the group into the DB
|
||||
err = db.Save(g).Error
|
||||
err := db.Save(g).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
|
@ -214,7 +214,7 @@ func PutGroup(g *Group) error {
|
|||
}
|
||||
// Fetch group's existing targets from database.
|
||||
ts := []Target{}
|
||||
ts, err = GetTargets(g.Id)
|
||||
ts, err := GetTargets(g.Id)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"group_id": g.Id,
|
||||
|
@ -234,7 +234,7 @@ func PutGroup(g *Group) error {
|
|||
}
|
||||
// If the target does not exist in the group any longer, we delete it
|
||||
if !tExists {
|
||||
err = db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error
|
||||
err := db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"email": t.Email,
|
||||
|
@ -286,14 +286,14 @@ func DeleteGroup(g *Group) error {
|
|||
}
|
||||
|
||||
func insertTargetIntoGroup(t Target, gid int64) error {
|
||||
if _, err = mail.ParseAddress(t.Email); err != nil {
|
||||
if _, err := mail.ParseAddress(t.Email); err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"email": t.Email,
|
||||
}).Error("Invalid email")
|
||||
return err
|
||||
}
|
||||
trans := db.Begin()
|
||||
err = trans.Where(t).FirstOrCreate(&t).Error
|
||||
err := trans.Where(t).FirstOrCreate(&t).Error
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"email": t.Email,
|
||||
|
|
|
@ -45,8 +45,7 @@ func GenerateMailLog(c *Campaign, r *Result, sendDate time.Time) error {
|
|||
RId: r.RId,
|
||||
SendDate: sendDate,
|
||||
}
|
||||
err = db.Save(m).Error
|
||||
return err
|
||||
return db.Save(m).Error
|
||||
}
|
||||
|
||||
// Backoff sets the MailLog SendDate to be the next entry in an exponential
|
||||
|
@ -160,8 +159,8 @@ func (m *MailLog) Generate(msg *gomail.Message) error {
|
|||
|
||||
// Add the transparency headers
|
||||
msg.SetHeader("X-Mailer", config.ServerName)
|
||||
if config.Conf.ContactAddress != "" {
|
||||
msg.SetHeader("X-Gophish-Contact", config.Conf.ContactAddress)
|
||||
if conf.ContactAddress != "" {
|
||||
msg.SetHeader("X-Gophish-Contact", conf.ContactAddress)
|
||||
}
|
||||
// Parse the customHeader templates
|
||||
for _, header := range c.SMTP.Headers {
|
||||
|
@ -260,6 +259,5 @@ func LockMailLogs(ms []*MailLog, lock bool) error {
|
|||
// in the database. This is intended to be called when Gophish is started
|
||||
// so that any previously locked maillogs can resume processing.
|
||||
func UnlockAllMailLogs() error {
|
||||
err = db.Model(&MailLog{}).Update("processing", false).Error
|
||||
return err
|
||||
return db.Model(&MailLog{}).Update("processing", false).Error
|
||||
}
|
||||
|
|
|
@ -37,7 +37,13 @@ func (s *ModelsSuite) emailFromFirstMailLog(campaign Campaign, ch *check.C) *ema
|
|||
|
||||
func (s *ModelsSuite) TestGetQueuedMailLogs(ch *check.C) {
|
||||
campaign := s.createCampaign(ch)
|
||||
ms, err := GetQueuedMailLogs(campaign.LaunchDate)
|
||||
// By default, for campaigns with no launch date, the maillogs are set as
|
||||
// being processed. We need to unlock them first.
|
||||
ms, err := GetMailLogsByCampaign(campaign.Id)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
err = LockMailLogs(ms, false)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
ms, err = GetQueuedMailLogs(campaign.LaunchDate)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
got := make(map[string]*MailLog)
|
||||
for _, m := range ms {
|
||||
|
@ -222,10 +228,10 @@ func (s *ModelsSuite) TestMailLogGenerate(ch *check.C) {
|
|||
}
|
||||
|
||||
func (s *ModelsSuite) TestMailLogGenerateTransparencyHeaders(ch *check.C) {
|
||||
config.Conf.ContactAddress = "test@test.com"
|
||||
s.config.ContactAddress = "test@test.com"
|
||||
expectedHeaders := map[string]string{
|
||||
"X-Mailer": config.ServerName,
|
||||
"X-Gophish-Contact": config.Conf.ContactAddress,
|
||||
"X-Gophish-Contact": s.config.ContactAddress,
|
||||
}
|
||||
campaign := s.createCampaign(ch)
|
||||
got := s.emailFromFirstMailLog(campaign, ch)
|
||||
|
@ -264,12 +270,6 @@ func (s *ModelsSuite) TestUnlockAllMailLogs(ch *check.C) {
|
|||
campaign := s.createCampaign(ch)
|
||||
ms, err := GetMailLogsByCampaign(campaign.Id)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
for _, m := range ms {
|
||||
ch.Assert(m.Processing, check.Equals, false)
|
||||
}
|
||||
err = LockMailLogs(ms, true)
|
||||
ms, err = GetMailLogsByCampaign(campaign.Id)
|
||||
ch.Assert(err, check.Equals, nil)
|
||||
for _, m := range ms {
|
||||
ch.Assert(m.Processing, check.Equals, true)
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
)
|
||||
|
||||
var db *gorm.DB
|
||||
var err error
|
||||
var conf *config.Config
|
||||
|
||||
const (
|
||||
CAMPAIGN_IN_PROGRESS string = "In progress"
|
||||
|
@ -78,12 +78,14 @@ func chooseDBDriver(name, openStr string) goose.DBDriver {
|
|||
|
||||
// Setup initializes the Conn object
|
||||
// It also populates the Gophish Config object
|
||||
func Setup() error {
|
||||
func Setup(c *config.Config) error {
|
||||
// Setup the package-scoped config
|
||||
conf = c
|
||||
// Setup the goose configuration
|
||||
migrateConf := &goose.DBConf{
|
||||
MigrationsDir: config.Conf.MigrationsPath,
|
||||
MigrationsDir: conf.MigrationsPath,
|
||||
Env: "production",
|
||||
Driver: chooseDBDriver(config.Conf.DBName, config.Conf.DBPath),
|
||||
Driver: chooseDBDriver(conf.DBName, conf.DBPath),
|
||||
}
|
||||
// Get the latest possible migration
|
||||
latest, err := goose.GetMostRecentDBVersion(migrateConf.MigrationsDir)
|
||||
|
@ -92,7 +94,7 @@ func Setup() error {
|
|||
return err
|
||||
}
|
||||
// Open our database connection
|
||||
db, err = gorm.Open(config.Conf.DBName, config.Conf.DBPath)
|
||||
db, err = gorm.Open(conf.DBName, conf.DBPath)
|
||||
db.LogMode(false)
|
||||
db.SetLogger(log.Logger)
|
||||
db.DB().SetMaxOpenConns(1)
|
||||
|
|
|
@ -10,15 +10,20 @@ import (
|
|||
// Hook up gocheck into the "go test" runner.
|
||||
func Test(t *testing.T) { check.TestingT(t) }
|
||||
|
||||
type ModelsSuite struct{}
|
||||
type ModelsSuite struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
var _ = check.Suite(&ModelsSuite{})
|
||||
|
||||
func (s *ModelsSuite) SetUpSuite(c *check.C) {
|
||||
config.Conf.DBName = "sqlite3"
|
||||
config.Conf.DBPath = ":memory:"
|
||||
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
|
||||
err := Setup()
|
||||
conf := &config.Config{
|
||||
DBName: "sqlite3",
|
||||
DBPath: ":memory:",
|
||||
MigrationsPath: "../db/db_sqlite3/migrations/",
|
||||
}
|
||||
s.config = conf
|
||||
err := Setup(conf)
|
||||
if err != nil {
|
||||
c.Fatalf("Failed creating database: %v", err)
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ func PutPage(p *Page) error {
|
|||
// DeletePage deletes an existing page in the database.
|
||||
// An error is returned if a page with the given user id and page id is not found.
|
||||
func DeletePage(id int64, uid int64) error {
|
||||
err = db.Where("user_id=?", uid).Delete(Page{Id: id}).Error
|
||||
err := db.Where("user_id=?", uid).Delete(Page{Id: id}).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
|
|
@ -228,7 +228,7 @@ func PutSMTP(s *SMTP) error {
|
|||
// An error is returned if a SMTP with the given user id and SMTP id is not found.
|
||||
func DeleteSMTP(id int64, uid int64) error {
|
||||
// Delete all custom headers
|
||||
err = db.Where("smtp_id=?", id).Delete(&Header{}).Error
|
||||
err := db.Where("smtp_id=?", id).Delete(&Header{}).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
|
|
|
@ -15,7 +15,7 @@ func (s *ModelsSuite) TestPostSMTP(c *check.C) {
|
|||
FromAddress: "Foo Bar <foo@example.com>",
|
||||
UserId: 1,
|
||||
}
|
||||
err = PostSMTP(&smtp)
|
||||
err := PostSMTP(&smtp)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
ss, err := GetSMTPs(1)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
|
@ -28,7 +28,7 @@ func (s *ModelsSuite) TestPostSMTPNoHost(c *check.C) {
|
|||
FromAddress: "Foo Bar <foo@example.com>",
|
||||
UserId: 1,
|
||||
}
|
||||
err = PostSMTP(&smtp)
|
||||
err := PostSMTP(&smtp)
|
||||
c.Assert(err, check.Equals, ErrHostNotSpecified)
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ func (s *ModelsSuite) TestPostSMTPNoFrom(c *check.C) {
|
|||
UserId: 1,
|
||||
Host: "1.1.1.1:25",
|
||||
}
|
||||
err = PostSMTP(&smtp)
|
||||
err := PostSMTP(&smtp)
|
||||
c.Assert(err, check.Equals, ErrFromAddressNotSpecified)
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ func (s *ModelsSuite) TestPostSMTPValidHeader(c *check.C) {
|
|||
Header{Key: "X-Mailer", Value: "gophish"},
|
||||
},
|
||||
}
|
||||
err = PostSMTP(&smtp)
|
||||
err := PostSMTP(&smtp)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
ss, err := GetSMTPs(1)
|
||||
c.Assert(err, check.Equals, nil)
|
||||
|
|
|
@ -34,10 +34,10 @@ func (t *Template) Validate() error {
|
|||
case t.Text == "" && t.HTML == "":
|
||||
return ErrTemplateMissingParameter
|
||||
}
|
||||
if err = ValidateTemplate(t.HTML); err != nil {
|
||||
if err := ValidateTemplate(t.HTML); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = ValidateTemplate(t.Text); err != nil {
|
||||
if err := ValidateTemplate(t.Text); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -113,7 +113,7 @@ func PostTemplate(t *Template) error {
|
|||
if err := t.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.Save(t).Error
|
||||
err := db.Save(t).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return err
|
||||
|
@ -138,7 +138,7 @@ func PutTemplate(t *Template) error {
|
|||
return err
|
||||
}
|
||||
// Delete all attachments, and replace with new ones
|
||||
err = db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error
|
||||
err := db.Where("template_id=?", t.Id).Delete(&Attachment{}).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
log.Error(err)
|
||||
return err
|
||||
|
@ -146,7 +146,7 @@ func PutTemplate(t *Template) error {
|
|||
if err == gorm.ErrRecordNotFound {
|
||||
err = nil
|
||||
}
|
||||
for i, _ := range t.Attachments {
|
||||
for i := range t.Attachments {
|
||||
t.Attachments[i].TemplateId = t.Id
|
||||
err := db.Save(&t.Attachments[i]).Error
|
||||
if err != nil {
|
||||
|
|
|
@ -18,10 +18,12 @@ type UtilSuite struct {
|
|||
}
|
||||
|
||||
func (s *UtilSuite) SetupSuite() {
|
||||
config.Conf.DBName = "sqlite3"
|
||||
config.Conf.DBPath = ":memory:"
|
||||
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
|
||||
err := models.Setup()
|
||||
conf := &config.Config{
|
||||
DBName: "sqlite3",
|
||||
DBPath: ":memory:",
|
||||
MigrationsPath: "../db/db_sqlite3/migrations/",
|
||||
}
|
||||
err := models.Setup(conf)
|
||||
if err != nil {
|
||||
s.T().Fatalf("Failed creating database: %v", err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/gophish/gophish/logger"
|
||||
|
@ -9,18 +10,46 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Worker is the background worker that handles watching for new campaigns and sending emails appropriately.
|
||||
type Worker struct{}
|
||||
// Worker is an interface that defines the operations needed for a background worker
|
||||
type Worker interface {
|
||||
Start()
|
||||
LaunchCampaign(c models.Campaign)
|
||||
SendTestEmail(s *models.EmailRequest) error
|
||||
}
|
||||
|
||||
// DefaultWorker is the background worker that handles watching for new campaigns and sending emails appropriately.
|
||||
type DefaultWorker struct {
|
||||
mailer mailer.Mailer
|
||||
}
|
||||
|
||||
// New creates a new worker object to handle the creation of campaigns
|
||||
func New() *Worker {
|
||||
return &Worker{}
|
||||
func New(options ...func(Worker) error) (Worker, error) {
|
||||
defaultMailer := mailer.NewMailWorker()
|
||||
w := &DefaultWorker{
|
||||
mailer: defaultMailer,
|
||||
}
|
||||
for _, opt := range options {
|
||||
if err := opt(w); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// WithMailer sets the mailer for a given worker.
|
||||
// By default, workers use a standard, default mailworker.
|
||||
func WithMailer(m mailer.Mailer) func(*DefaultWorker) error {
|
||||
return func(w *DefaultWorker) error {
|
||||
w.mailer = m
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the worker to poll the database every minute for any pending maillogs
|
||||
// that need to be processed.
|
||||
func (w *Worker) Start() {
|
||||
func (w *DefaultWorker) Start() {
|
||||
log.Info("Background Worker Started Successfully - Waiting for Campaigns")
|
||||
go w.mailer.Start(context.Background())
|
||||
for t := range time.Tick(1 * time.Minute) {
|
||||
ms, err := models.GetQueuedMailLogs(t.UTC())
|
||||
if err != nil {
|
||||
|
@ -62,14 +91,14 @@ func (w *Worker) Start() {
|
|||
log.WithFields(logrus.Fields{
|
||||
"num_emails": len(msc),
|
||||
}).Info("Sending emails to mailer for processing")
|
||||
mailer.Mailer.Queue <- msc
|
||||
w.mailer.Queue(msc)
|
||||
}(cid, msc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LaunchCampaign starts a campaign
|
||||
func (w *Worker) LaunchCampaign(c models.Campaign) {
|
||||
func (w *DefaultWorker) LaunchCampaign(c models.Campaign) {
|
||||
ms, err := models.GetMailLogsByCampaign(c.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
|
@ -89,14 +118,14 @@ func (w *Worker) LaunchCampaign(c models.Campaign) {
|
|||
}
|
||||
mailEntries = append(mailEntries, m)
|
||||
}
|
||||
mailer.Mailer.Queue <- mailEntries
|
||||
w.mailer.Queue(mailEntries)
|
||||
}
|
||||
|
||||
// SendTestEmail sends a test email
|
||||
func (w *Worker) SendTestEmail(s *models.EmailRequest) error {
|
||||
func (w *DefaultWorker) SendTestEmail(s *models.EmailRequest) error {
|
||||
go func() {
|
||||
ms := []mailer.Mail{s}
|
||||
mailer.Mailer.Queue <- ms
|
||||
w.mailer.Queue(ms)
|
||||
}()
|
||||
return <-s.ErrorChan
|
||||
}
|
||||
|
|
|
@ -9,17 +9,20 @@ import (
|
|||
// WorkerSuite is a suite of tests to cover API related functions
|
||||
type WorkerSuite struct {
|
||||
suite.Suite
|
||||
ApiKey string
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func (s *WorkerSuite) SetupSuite() {
|
||||
config.Conf.DBName = "sqlite3"
|
||||
config.Conf.DBPath = ":memory:"
|
||||
config.Conf.MigrationsPath = "../db/db_sqlite3/migrations/"
|
||||
err := models.Setup()
|
||||
conf := &config.Config{
|
||||
DBName: "sqlite3",
|
||||
DBPath: ":memory:",
|
||||
MigrationsPath: "../db/db_sqlite3/migrations/",
|
||||
}
|
||||
err := models.Setup(conf)
|
||||
if err != nil {
|
||||
s.T().Fatalf("Failed creating database: %v", err)
|
||||
}
|
||||
s.config = conf
|
||||
s.Nil(err)
|
||||
}
|
||||
|
||||
|
@ -31,7 +34,7 @@ func (s *WorkerSuite) TearDownTest() {
|
|||
}
|
||||
|
||||
func (s *WorkerSuite) SetupTest() {
|
||||
config.Conf.TestFlag = true
|
||||
s.config.TestFlag = true
|
||||
// Add a group
|
||||
group := models.Group{Name: "Test Group"}
|
||||
group.Targets = []models.Target{
|
||||
|
@ -73,7 +76,3 @@ func (s *WorkerSuite) SetupTest() {
|
|||
models.PostCampaign(&c, c.UserId)
|
||||
c.UpdateStatus(models.CAMPAIGN_EMAILS_SENT)
|
||||
}
|
||||
|
||||
func (s *WorkerSuite) TestMailSendSuccess() {
|
||||
// TODO
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue