Added support for --setup flag to reset database

pull/24/head
Jordan 2014-01-12 22:39:40 -06:00
parent c42ddf3dd7
commit 6944854005
4 changed files with 30 additions and 27 deletions

View File

@ -53,6 +53,8 @@ func Login(r *http.Request) (bool, error) {
return true, nil return true, nil
} }
// GetUserById returns the user that the given id corresponds to. If no user is found, an
// error is thrown.
func GetUserById(id int) (models.User, error) { func GetUserById(id int) (models.User, error) {
u := models.User{} u := models.User{}
stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE id=?") stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE id=?")
@ -61,12 +63,13 @@ func GetUserById(id int) (models.User, error) {
} }
err = stmt.QueryRow(id).Scan(&u.Id, &u.Username, &u.APIKey) err = stmt.QueryRow(id).Scan(&u.Id, &u.Username, &u.APIKey)
if err != nil { if err != nil {
//Return false, but don't return an error
return u, err return u, err
} }
return u, nil return u, nil
} }
// GetUserByAPIKey returns the user that the given API Key corresponds to. If no user is found, an
// error is thrown.
func GetUserByAPIKey(key []byte) (models.User, error) { func GetUserByAPIKey(key []byte) (models.User, error) {
u := models.User{} u := models.User{}
stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE apikey=?") stmt, err := db.Conn.Prepare("SELECT id, username, apikey FROM Users WHERE apikey=?")
@ -75,7 +78,6 @@ func GetUserByAPIKey(key []byte) (models.User, error) {
} }
err = stmt.QueryRow(key).Scan(&u.Id, &u.Username, &u.APIKey) err = stmt.QueryRow(key).Scan(&u.Id, &u.Username, &u.APIKey)
if err != nil { if err != nil {
//Return false, but don't return an error
return u, err return u, err
} }
return u, nil return u, nil

View File

@ -83,8 +83,7 @@ func Base(w http.ResponseWriter, r *http.Request) {
User models.User User models.User
Title string Title string
Flashes []interface{} Flashes []interface{}
}{} }{Title: "Dashboard", User: ctx.Get(r, "user").(models.User)}
params.User = ctx.Get(r, "user").(models.User)
fmt.Println(params.User.Username) fmt.Println(params.User.Username)
getTemplate(w, "dashboard").ExecuteTemplate(w, "base", nil) getTemplate(w, "dashboard").ExecuteTemplate(w, "base", nil)
} }
@ -112,9 +111,8 @@ func Login(w http.ResponseWriter, r *http.Request) {
User models.User User models.User
Title string Title string
Flashes []interface{} Flashes []interface{}
}{} }{Title: "Login"}
session := ctx.Get(r, "session").(*sessions.Session) session := ctx.Get(r, "session").(*sessions.Session)
params.Title = "Login"
switch { switch {
case r.Method == "GET": case r.Method == "GET":
getTemplate(w, "login").ExecuteTemplate(w, "base", params) getTemplate(w, "login").ExecuteTemplate(w, "base", params)

View File

@ -13,38 +13,40 @@ var Conn *sql.DB
// Setup initializes the Conn object // Setup initializes the Conn object
// It also populates the Gophish Config object // It also populates the Gophish Config object
func Setup() error { func Setup(reset bool) error {
createTablesSQL := []string{
//Create tables
`CREATE TABLE Users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, apikey VARCHAR(32));`,
`CREATE TABLE Campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TEXT NOT NULL, completed_date TEXT, status TEXT NOT NULL);`,
}
//If the file already exists, delete it and recreate it //If the file already exists, delete it and recreate it
_, err := os.Stat(config.Conf.DBPath) _, err := os.Stat(config.Conf.DBPath)
if err == nil { if err == nil {
os.Remove(config.Conf.DBPath) os.Remove(config.Conf.DBPath)
} }
fmt.Println("Creating db at " + config.Conf.DBPath)
Conn, err = sql.Open("sqlite3", config.Conf.DBPath) Conn, err = sql.Open("sqlite3", config.Conf.DBPath)
if err != nil { if err != nil {
return err return err
} }
//Create the tables needed if reset {
for _, stmt := range createTablesSQL { createTablesSQL := []string{
_, err = Conn.Exec(stmt) //Create tables
`CREATE TABLE Users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, apikey VARCHAR(32));`,
`CREATE TABLE Campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TEXT NOT NULL, completed_date TEXT, status TEXT NOT NULL);`,
}
fmt.Println("Creating db at " + config.Conf.DBPath)
//Create the tables needed
for _, stmt := range createTablesSQL {
_, err = Conn.Exec(stmt)
if err != nil {
return err
}
}
//Create the default user
stmt, err := Conn.Prepare(`INSERT INTO Users (username, hash, apikey) VALUES (?, ?, ?);`)
defer stmt.Close()
if err != nil {
return err
}
_, err = stmt.Exec("jordan", "$2a$10$d4OtT.RkEOQn.iruVWIQ5u8CeV/85ZYF41y8wKeUwsAPqPNFvTccW", "12345678901234567890123456789012")
if err != nil { if err != nil {
return err return err
} }
} }
//Create the default user
stmt, err := Conn.Prepare(`INSERT INTO Users (username, hash, apikey) VALUES (?, ?, ?);`)
defer stmt.Close()
if err != nil {
return err
}
_, err = stmt.Exec("jordan", "$2a$10$d4OtT.RkEOQn.iruVWIQ5u8CeV/85ZYF41y8wKeUwsAPqPNFvTccW", "12345678901234567890123456789012")
if err != nil {
return err
}
return nil return nil
} }

View File

@ -40,7 +40,8 @@ var setupFlag = flag.Bool("setup", false, "Starts the initial setup process for
func main() { func main() {
//Setup the global variables and settings //Setup the global variables and settings
err := db.Setup() flag.Parse()
err := db.Setup(*setupFlag)
defer db.Conn.Close() defer db.Conn.Close()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)