diff --git a/auth/auth.go b/auth/auth.go index 45193fca..33033051 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -53,6 +53,8 @@ func Login(r *http.Request) (bool, error) { return true, nil } +// GetUserById returns the user that the given id corresponds to. If no user is found, an +// error is thrown. func GetUserById(id int) (models.User, error) { u := models.User{} stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE id=?") @@ -61,12 +63,13 @@ func GetUserById(id int) (models.User, error) { } err = stmt.QueryRow(id).Scan(&u.Id, &u.Username, &u.APIKey) if err != nil { - //Return false, but don't return an error return u, err } return u, nil } +// GetUserByAPIKey returns the user that the given API Key corresponds to. If no user is found, an +// error is thrown. func GetUserByAPIKey(key []byte) (models.User, error) { u := models.User{} stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE apikey=?") @@ -75,7 +78,6 @@ func GetUserByAPIKey(key []byte) (models.User, error) { } err = stmt.QueryRow(key).Scan(&u.Id, &u.Username, &u.APIKey) if err != nil { - //Return false, but don't return an error return u, err } return u, nil diff --git a/controllers/route.go b/controllers/route.go index e81eebbc..a4770fc5 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -83,8 +83,7 @@ func Base(w http.ResponseWriter, r *http.Request) { User models.User Title string Flashes []interface{} - }{} - params.User = ctx.Get(r, "user").(models.User) + }{Title: "Dashboard", User: ctx.Get(r, "user").(models.User)} fmt.Println(params.User.Username) getTemplate(w, "dashboard").ExecuteTemplate(w, "base", nil) } @@ -112,9 +111,8 @@ func Login(w http.ResponseWriter, r *http.Request) { User models.User Title string Flashes []interface{} - }{} + }{Title: "Login"} session := ctx.Get(r, "session").(*sessions.Session) - params.Title = "Login" switch { case r.Method == "GET": getTemplate(w, "login").ExecuteTemplate(w, "base", params) diff --git a/db/db.go b/db/db.go index 245b7133..88742cfa 100644 --- a/db/db.go +++ b/db/db.go @@ -13,38 +13,40 @@ var Conn *sql.DB // Setup initializes the Conn object // It also populates the Gophish Config object -func Setup() error { - createTablesSQL := []string{ - //Create tables - `CREATE TABLE Users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, apikey VARCHAR(32));`, - `CREATE TABLE Campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TEXT NOT NULL, completed_date TEXT, status TEXT NOT NULL);`, - } +func Setup(reset bool) error { //If the file already exists, delete it and recreate it _, err := os.Stat(config.Conf.DBPath) if err == nil { os.Remove(config.Conf.DBPath) } - fmt.Println("Creating db at " + config.Conf.DBPath) Conn, err = sql.Open("sqlite3", config.Conf.DBPath) if err != nil { return err } - //Create the tables needed - for _, stmt := range createTablesSQL { - _, err = Conn.Exec(stmt) + if reset { + createTablesSQL := []string{ + //Create tables + `CREATE TABLE Users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, apikey VARCHAR(32));`, + `CREATE TABLE Campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TEXT NOT NULL, completed_date TEXT, status TEXT NOT NULL);`, + } + fmt.Println("Creating db at " + config.Conf.DBPath) + //Create the tables needed + for _, stmt := range createTablesSQL { + _, err = Conn.Exec(stmt) + if err != nil { + return err + } + } + //Create the default user + stmt, err := Conn.Prepare(`INSERT INTO Users (username, hash, apikey) VALUES (?, ?, ?);`) + defer stmt.Close() + if err != nil { + return err + } + _, err = stmt.Exec("jordan", "$2a$10$d4OtT.RkEOQn.iruVWIQ5u8CeV/85ZYF41y8wKeUwsAPqPNFvTccW", "12345678901234567890123456789012") if err != nil { return err } } - //Create the default user - stmt, err := Conn.Prepare(`INSERT INTO Users (username, hash, apikey) VALUES (?, ?, ?);`) - defer stmt.Close() - if err != nil { - return err - } - _, err = stmt.Exec("jordan", "$2a$10$d4OtT.RkEOQn.iruVWIQ5u8CeV/85ZYF41y8wKeUwsAPqPNFvTccW", "12345678901234567890123456789012") - if err != nil { - return err - } return nil } diff --git a/gophish.go b/gophish.go index 39cbef17..6ca935aa 100644 --- a/gophish.go +++ b/gophish.go @@ -40,7 +40,8 @@ var setupFlag = flag.Bool("setup", false, "Starts the initial setup process for func main() { //Setup the global variables and settings - err := db.Setup() + flag.Parse() + err := db.Setup(*setupFlag) defer db.Conn.Close() if err != nil { fmt.Println(err)