diff --git a/.travis.yml b/.travis.yml index 9a18d00e..cb2cc2e7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,4 +2,8 @@ language: go go: - 1.1 - - tip \ No newline at end of file + - tip + +install: + - go get -d -v ./... && go build -v ./... + - go get launchpad.net/gocheck \ No newline at end of file diff --git a/auth/auth.go b/auth/auth.go index ab268d57..d6409c72 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -33,7 +33,7 @@ func Login(r *http.Request) (bool, error) { username, password := r.FormValue("username"), r.FormValue("password") session, _ := Store.Get(r, "gophish") u, err := models.GetUserByUsername(username) - if err != models.ErrUsernameTaken { + if err != nil && err != models.ErrUsernameTaken { return false, err } //If we've made it here, we should have a valid user stored in u @@ -61,7 +61,7 @@ func Register(r *http.Request) (bool, error) { h, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) u.Username = username u.Hash = string(h) - u.APIKey = GenerateSecureKey() + u.ApiKey = GenerateSecureKey() if err != nil { return false, err } diff --git a/config/config.go b/config/config.go index 3884c9f7..a92e3a41 100644 --- a/config/config.go +++ b/config/config.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "os" ) type SMTPServer struct { @@ -27,7 +26,6 @@ func init() { config_file, err := ioutil.ReadFile("./config.json") if err != nil { fmt.Printf("File error: %v\n", err) - os.Exit(1) } json.Unmarshal(config_file, &Conf) } diff --git a/controllers/api.go b/controllers/api.go index 17a880d0..b535ff17 100644 --- a/controllers/api.go +++ b/controllers/api.go @@ -39,7 +39,7 @@ func API_Reset(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "POST": u := ctx.Get(r, "user").(models.User) - u.APIKey = auth.GenerateSecureKey() + u.ApiKey = auth.GenerateSecureKey() err := models.PutUser(&u) if err != nil { Flash(w, r, "danger", "Error resetting API Key") @@ -80,6 +80,7 @@ func API_Campaigns(w http.ResponseWriter, r *http.Request) { c.CreatedDate = time.Now() c.CompletedDate = time.Time{} c.Status = IN_PROGRESS + c.UserId = ctx.Get(r, "user_id").(int64) err = models.PostCampaign(&c, ctx.Get(r, "user_id").(int64)) if checkError(err, w, "Cannot insert campaign into database", http.StatusInternalServerError) { return @@ -175,7 +176,8 @@ func API_Groups(w http.ResponseWriter, r *http.Request) { return } g.ModifiedDate = time.Now() - err = models.PostGroup(&g, ctx.Get(r, "user_id").(int64)) + g.UserId = ctx.Get(r, "user_id").(int64) + err = models.PostGroup(&g) if checkError(err, w, "Error inserting group", http.StatusInternalServerError) { return } @@ -204,11 +206,11 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) { } writeJSON(w, gj) case r.Method == "DELETE": - _, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) + g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) if checkError(err, w, "No group found", http.StatusNotFound) { return } - err = models.DeleteGroup(id) + err = models.DeleteGroup(&g) if checkError(err, w, "Error deleting group", http.StatusInternalServerError) { return } @@ -229,7 +231,9 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error: No targets specified", http.StatusBadRequest) return } - err = models.PutGroup(&g, ctx.Get(r, "user_id").(int64)) + g.ModifiedDate = time.Now() + g.UserId = ctx.Get(r, "user_id").(int64) + err = models.PutGroup(&g) if checkError(err, w, "Error updating group", http.StatusInternalServerError) { return } diff --git a/gophish.go b/gophish.go index 9433e5ef..acc45e62 100644 --- a/gophish.go +++ b/gophish.go @@ -38,7 +38,7 @@ import ( func main() { //Setup the global variables and settings err := models.Setup() - defer models.DB.Close() + //defer models.db.DB().Close() if err != nil { fmt.Println(err) } diff --git a/gophish_test.go b/gophish_test.go deleted file mode 100644 index 97510b9a..00000000 --- a/gophish_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package main - -import ( - "testing" - - "github.com/jordan-wright/gophish/models" -) - -func TestDBSetup(t *testing.T) { - err := models.Setup() - if err != nil { - t.Fatalf("Failed creating database: %v", err) - } -} diff --git a/middleware/middleware.go b/middleware/middleware.go index 0d69d905..797bc3bf 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -52,12 +52,13 @@ func RequireAPIKey(handler http.Handler) http.HandlerFunc { if ak == "" { JSONError(w, 400, "API Key not set") } else { - id, err := models.Conn.SelectInt("SELECT id FROM users WHERE api_key=?", ak) - if id == 0 || err != nil { + u, err := models.GetUserByAPIKey(ak) + /* id, err := models.Conn.SelectInt("SELECT id FROM users WHERE api_key=?", ak) + */if err != nil { JSONError(w, 400, "Invalid API Key") return } - ctx.Set(r, "user_id", id) + ctx.Set(r, "user_id", u.Id) ctx.Set(r, "api_key", ak) handler.ServeHTTP(w, r) } diff --git a/models/campaign.go b/models/campaign.go index 9db54ee0..46d91887 100644 --- a/models/campaign.go +++ b/models/campaign.go @@ -9,38 +9,48 @@ import ( //Campaign is a struct representing a created campaign type Campaign struct { Id int64 `json:"id"` - Name string `json:"name"` - CreatedDate time.Time `json:"created_date" db:"created_date"` - CompletedDate time.Time `json:"completed_date" db:"completed_date"` + UserId int64 `json:"-"` + Name string `json:"name" sql:"not null"` + CreatedDate time.Time `json:"created_date"` + CompletedDate time.Time `json:"completed_date"` Template string `json:"template"` //This may change Status string `json:"status"` - Results []Result `json:"results,omitempty" db:"-"` - Groups []Group `json:"groups,omitempty" db:"-"` + Results []Result `json:"results,omitempty"` + Groups []Group `json:"groups,omitempty"` } type Result struct { - Target - Status string `json:"status"` + Id int64 `json:"-"` + CampaignId int64 `json:"-"` + Email string `json:"email"` + Status string `json:"status" sql:"not null"` } // GetCampaigns returns the campaigns owned by the given user. func GetCampaigns(uid int64) ([]Campaign, error) { cs := []Campaign{} - _, err := Conn.Select(&cs, "SELECT c.id, name, created_date, completed_date, status, template FROM campaigns c, user_campaigns uc, users u WHERE uc.uid=u.id AND uc.cid=c.id AND u.id=?", uid) - for i, _ := range cs { - _, err = Conn.Select(&cs[i].Results, "SELECT r.email, r.status FROM campaign_results r WHERE r.cid=?", cs[i].Id) + err := db.Model(&User{Id: uid}).Related(&cs).Error + if err != nil { + fmt.Println(err) } + for i, _ := range cs { + err := db.Model(&cs[i]).Related(&cs[i].Results).Error + if err != nil { + fmt.Println(err) + } + } + fmt.Printf("%v", cs) return cs, err } // GetCampaign returns the campaign, if it exists, specified by the given id and user_id. func GetCampaign(id int64, uid int64) (Campaign, error) { c := Campaign{} - err := Conn.SelectOne(&c, "SELECT c.id, name, created_date, completed_date, status, template FROM campaigns c, user_campaigns uc, users u WHERE uc.uid=u.id AND uc.cid=c.id AND c.id=? AND u.id=?", id, uid) + err := db.Where("id = ?", id).Where("user_id = ?", uid).Find(&c).Error if err != nil { return c, err } - _, err = Conn.Select(&c.Results, "SELECT r.email, r.status FROM campaign_results r WHERE r.cid=?", c.Id) + err = db.Model(&c).Related(&c.Results).Error return c, err } @@ -58,7 +68,7 @@ func PostCampaign(c *Campaign, uid int64) error { } } // Insert into the DB - err = Conn.Insert(c) + err = db.Save(c).Error if err != nil { Logger.Println(err) return err @@ -67,35 +77,32 @@ func PostCampaign(c *Campaign, uid int64) error { for _, g := range c.Groups { // Insert a result for each target in the group for _, t := range g.Targets { - r := Result{Target: t, Status: "Unknown"} + r := Result{Email: t.Email, Status: "Unknown", CampaignId: c.Id} c.Results = append(c.Results, r) fmt.Printf("%v", c.Results) - _, err = Conn.Exec("INSERT INTO campaign_results VALUES (?,?,?)", c.Id, r.Email, r.Status) + err := db.Save(&r).Error if err != nil { Logger.Printf("Error adding result record for target %s\n", t.Email) Logger.Println(err) } } } - _, err = Conn.Exec("INSERT OR IGNORE INTO user_campaigns VALUES (?,?)", uid, c.Id) - if err != nil { - Logger.Printf("Error adding many-many mapping for campaign %s\n", c.Name) - } return nil } +//DeleteCampaign deletes the specified campaign func DeleteCampaign(id int64) error { - // Delete all the campaign_results entries for this group - _, err := Conn.Exec("DELETE FROM campaign_results WHERE cid=?", id) + // Delete all the campaign results + err := db.Delete(&Result{CampaignId: id}).Error if err != nil { + Logger.Println(err) return err } - // Delete the reference to the campaign in the user_campaigns table - _, err = Conn.Exec("DELETE FROM user_campaigns WHERE cid=?", id) + // Delete the campaign + err = db.Delete(&Campaign{Id: id}).Error if err != nil { + Logger.Panicln(err) return err } - // Delete the campaign itself - _, err = Conn.Exec("DELETE FROM campaigns WHERE id=?", id) return err } diff --git a/models/group.go b/models/group.go index 1305dda4..dc80c501 100644 --- a/models/group.go +++ b/models/group.go @@ -3,13 +3,21 @@ package models import ( "net/mail" "time" + + "github.com/jinzhu/gorm" ) type Group struct { Id int64 `json:"id"` + UserId int64 `json:"-"` Name string `json:"name"` - ModifiedDate time.Time `json:"modified_date" db:"modified_date"` - Targets []Target `json:"targets" db:"-"` + ModifiedDate time.Time `json:"modified_date"` + Targets []Target `json:"targets" sql:"-"` +} + +type GroupTarget struct { + GroupId int64 `json:"-"` + TargetId int64 `json:"-"` } type Target struct { @@ -20,13 +28,13 @@ type Target struct { // GetGroups returns the groups owned by the given user. func GetGroups(uid int64) ([]Group, error) { gs := []Group{} - _, err := Conn.Select(&gs, "SELECT g.id, g.name, g.modified_date FROM groups g, user_groups ug, users u WHERE ug.uid=u.id AND ug.gid=g.id AND u.id=?", uid) + err := db.Where("user_id=?", uid).Find(&gs).Error if err != nil { Logger.Println(err) return gs, err } for i, _ := range gs { - _, err := Conn.Select(&gs[i].Targets, "SELECT t.id, t.email FROM targets t, group_targets gt WHERE gt.gid=? AND gt.tid=t.id", gs[i].Id) + gs[i].Targets, err = GetTargets(gs[i].Id) if err != nil { Logger.Println(err) } @@ -37,12 +45,12 @@ func GetGroups(uid int64) ([]Group, error) { // GetGroup returns the group, if it exists, specified by the given id and user_id. func GetGroup(id int64, uid int64) (Group, error) { g := Group{} - err := Conn.SelectOne(&g, "SELECT g.id, g.name, g.modified_date FROM groups g, user_groups ug, users u WHERE ug.uid=u.id AND ug.gid=g.id AND g.id=? AND u.id=?", id, uid) + err := db.Where("user_id=? and id=?", uid, id).Find(&g).Error if err != nil { Logger.Println(err) return g, err } - _, err = Conn.Select(&g.Targets, "SELECT t.id, t.email FROM targets t, group_targets gt WHERE gt.gid=? AND gt.tid=t.id", g.Id) + g.Targets, err = GetTargets(g.Id) if err != nil { Logger.Println(err) } @@ -52,12 +60,12 @@ func GetGroup(id int64, uid int64) (Group, error) { // GetGroupByName returns the group, if it exists, specified by the given name and user_id. func GetGroupByName(n string, uid int64) (Group, error) { g := Group{} - err := Conn.SelectOne(&g, "SELECT g.id, g.name, g.modified_date FROM groups g, user_groups ug, users u WHERE ug.uid=u.id AND ug.gid=g.id AND g.name=? AND u.id=?", n, uid) + err := db.Where("user_id=? and name=?", uid, n).Find(&g).Error if err != nil { Logger.Println(err) return g, err } - _, err = Conn.Select(&g.Targets, "SELECT t.id, t.email FROM targets t, group_targets gt WHERE gt.gid=? AND gt.tid=t.id", g.Id) + g.Targets, err = GetTargets(g.Id) if err != nil { Logger.Println(err) } @@ -65,18 +73,13 @@ func GetGroupByName(n string, uid int64) (Group, error) { } // PostGroup creates a new group in the database. -func PostGroup(g *Group, uid int64) error { +func PostGroup(g *Group) error { // Insert into the DB - err = Conn.Insert(g) + err = db.Save(g).Error if err != nil { Logger.Println(err) return err } - // Now, let's add the user->user_groups->group mapping - _, err = Conn.Exec("INSERT OR IGNORE INTO user_groups VALUES (?,?)", uid, g.Id) - if err != nil { - Logger.Printf("Error adding many-many mapping for group %s\n", g.Name) - } for _, t := range g.Targets { insertTargetIntoGroup(t, g.Id) } @@ -84,13 +87,9 @@ func PostGroup(g *Group, uid int64) error { } // PutGroup updates the given group if found in the database. -func PutGroup(g *Group, uid int64) error { - // Update all the foreign keys, and many to many relationships - // We will only delete the group->targets entries. We keep the actual targets - // since they are needed by the Results table - // Get all the targets currently in the database for the group +func PutGroup(g *Group) error { ts := []Target{} - _, err = Conn.Select(&ts, "SELECT t.id, t.email FROM targets t, group_targets gt WHERE gt.gid=? AND gt.tid=t.id", g.Id) + ts, err = GetTargets(g.Id) if err != nil { Logger.Printf("Error getting targets from group ID: %d", g.Id) return err @@ -109,7 +108,7 @@ func PutGroup(g *Group, uid int64) error { } // If the target does not exist in the group any longer, we delete it if !tExists { - _, err = Conn.Exec("DELETE FROM group_targets WHERE gid=? AND tid=?", g.Id, t.Id) + err = db.Where("group_id=? and target_id=?", g.Id, t.Id).Delete(&GroupTarget{}).Error if err != nil { Logger.Printf("Error deleting email %s\n", t.Email) } @@ -131,9 +130,8 @@ func PutGroup(g *Group, uid int64) error { insertTargetIntoGroup(nt, g.Id) } } - // Update the group - g.ModifiedDate = time.Now() - _, err = Conn.Update(g) + err = db.Save(g).Error + /*_, err = Conn.Update(g)*/ if err != nil { Logger.Println(err) return err @@ -141,33 +139,47 @@ func PutGroup(g *Group, uid int64) error { return nil } +// DeleteGroup deletes a given group by group ID and user ID +func DeleteGroup(g *Group) error { + // Delete all the group_targets entries for this group + err := db.Where("group_id=?", g.Id).Delete(&GroupTarget{}).Error + if err != nil { + Logger.Println(err) + return err + } + // Delete the group itself + err = db.Delete(g).Error + if err != nil { + Logger.Println(err) + return err + } + return err +} + func insertTargetIntoGroup(t Target, gid int64) error { if _, err = mail.ParseAddress(t.Email); err != nil { Logger.Printf("Invalid email %s\n", t.Email) return err } - trans, err := Conn.Begin() + trans := db.Begin() + trans.Where(t).FirstOrCreate(&t) if err != nil { - Logger.Println(err) + Logger.Printf("Error adding target: %s\n", t.Email) return err } - _, err = trans.Exec("INSERT OR IGNORE INTO targets VALUES (null, ?)", t.Email) - if err != nil { - Logger.Printf("Error adding email: %s\n", t.Email) - return err + err = trans.Where("group_id=? and target_id=?", gid, t.Id).Find(&GroupTarget{}).Error + if err == gorm.RecordNotFound { + err = trans.Save(&GroupTarget{GroupId: gid, TargetId: t.Id}).Error + if err != nil { + Logger.Println(err) + return err + } } - // Bug: res.LastInsertId() does not work for this, so we need to select it manually (how frustrating.) - t.Id, err = trans.SelectInt("SELECT id FROM targets WHERE email=?", t.Email) - if err != nil { - Logger.Printf("Error getting id for email: %s\n", t.Email) - return err - } - _, err = trans.Exec("INSERT OR IGNORE INTO group_targets VALUES (?,?)", gid, t.Id) if err != nil { Logger.Printf("Error adding many-many mapping for %s\n", t.Email) return err } - err = trans.Commit() + err = trans.Commit().Error if err != nil { Logger.Printf("Error committing db changes\n") return err @@ -175,19 +187,8 @@ func insertTargetIntoGroup(t Target, gid int64) error { return nil } -// DeleteGroup deletes a given group by group ID and user ID -func DeleteGroup(id int64) error { - // Delete all the group_targets entries for this group - _, err := Conn.Exec("DELETE FROM group_targets WHERE gid=?", id) - if err != nil { - return err - } - // Delete the reference to the group in the user_group table - _, err = Conn.Exec("DELETE FROM user_groups WHERE gid=?", id) - if err != nil { - return err - } - // Delete the group itself - _, err = Conn.Exec("DELETE FROM groups WHERE id=?", id) - return err +func GetTargets(gid int64) ([]Target, error) { + ts := []Target{} + err := db.Table("targets t").Select("t.id, t.email").Joins("left join group_targets gt ON t.id = gt.target_id").Where("gt.group_id=?", gid).Scan(&ts).Error + return ts, err } diff --git a/models/models.go b/models/models.go index daf8343d..87cbc372 100644 --- a/models/models.go +++ b/models/models.go @@ -1,74 +1,60 @@ package models import ( - "database/sql" "errors" "log" "os" "github.com/coopernurse/gorp" + "github.com/jinzhu/gorm" "github.com/jordan-wright/gophish/config" _ "github.com/mattn/go-sqlite3" ) var Conn *gorp.DbMap -var DB *sql.DB +var db gorm.DB var err error -var ErrUsernameTaken = errors.New("Username already taken") -var Logger = log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) - -// Setup initializes the Conn object -// It also populates the Gophish Config object -func Setup() error { - DB, err := sql.Open("sqlite3", config.Conf.DBPath) - Conn = &gorp.DbMap{Db: DB, Dialect: gorp.SqliteDialect{}} - //If the file already exists, delete it and recreate it - _, err = os.Stat(config.Conf.DBPath) - Conn.AddTableWithName(User{}, "users").SetKeys(true, "Id") - Conn.AddTableWithName(Campaign{}, "campaigns").SetKeys(true, "Id") - Conn.AddTableWithName(Group{}, "groups").SetKeys(true, "Id") - Conn.AddTableWithName(Template{}, "templates").SetKeys(true, "Id") - if err != nil { - Logger.Println("Database not found, recreating...") - createTablesSQL := []string{ - //Create tables - `CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, hash VARCHAR(60) NOT NULL, api_key VARCHAR(32), UNIQUE(username), UNIQUE(api_key));`, - `CREATE TABLE campaigns (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, created_date TIMESTAMP NOT NULL, completed_date TIMESTAMP, template TEXT, status TEXT NOT NULL);`, - `CREATE TABLE targets (id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL, UNIQUE(email));`, - `CREATE TABLE groups (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, modified_date TIMESTAMP NOT NULL);`, - `CREATE TABLE campaign_results (cid INTEGER NOT NULL, email TEXT NOT NULL, status TEXT NOT NULL, FOREIGN KEY (cid) REFERENCES campaigns(id), UNIQUE(cid, email, status))`, - `CREATE TABLE templates (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, modified_date TIMESTAMP NOT NULL, html TEXT NOT NULL, text TEXT NOT NULL);`, - `CREATE TABLE files (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, path TEXT NOT NULL);`, - `CREATE TABLE user_campaigns (uid INTEGER NOT NULL, cid INTEGER NOT NULL, FOREIGN KEY (uid) REFERENCES users(id), FOREIGN KEY (cid) REFERENCES campaigns(id), UNIQUE(uid, cid))`, - `CREATE TABLE user_groups (uid INTEGER NOT NULL, gid INTEGER NOT NULL, FOREIGN KEY (uid) REFERENCES users(id), FOREIGN KEY (gid) REFERENCES groups(id), UNIQUE(uid, gid))`, - `CREATE TABLE group_targets (gid INTEGER NOT NULL, tid INTEGER NOT NULL, FOREIGN KEY (gid) REFERENCES groups(id), FOREIGN KEY (tid) REFERENCES targets(id), UNIQUE(gid, tid));`, - `CREATE TABLE user_templates (uid INTEGER NOT NULL, tid INTEGER NOT NULL, FOREIGN KEY (uid) REFERENCES users(id), FOREIGN KEY (tid) REFERENCES templates(id), UNIQUE(uid, tid));`, - `CREATE TABLE template_files (tid INTEGER NOT NULL, fid INTEGER NOT NULL, FOREIGN KEY (tid) REFERENCES templates(id), FOREIGN KEY(fid) REFERENCES files(id), UNIQUE(tid, fid));`, - } - Logger.Printf("Creating db at %s\n", config.Conf.DBPath) - //Create the tables needed - for _, stmt := range createTablesSQL { - _, err = DB.Exec(stmt) - if err != nil { - return err - } - } - //Create the default user - init_user := User{ - Username: "admin", - Hash: "$2a$10$IYkPp0.QsM81lYYPrQx6W.U6oQGw7wMpozrKhKAHUBVL4mkm/EvAS", //gophish - APIKey: "12345678901234567890123456789012", - } - Conn.Insert(&init_user) - if err != nil { - Logger.Println(err) - } - } - return nil -} +var ErrUsernameTaken = errors.New("username already taken") +var Logger = log.New(os.Stdout, " ", log.Ldate|log.Ltime|log.Lshortfile) // Flash is used to hold flash information for use in templates. type Flash struct { Type string Message string } + +// Setup initializes the Conn object +// It also populates the Gophish Config object +func Setup() error { + db, err = gorm.Open("sqlite3", config.Conf.DBPath) + db.LogMode(true) + db.SetLogger(Logger) + if err != nil { + Logger.Println(err) + return err + } + //If the file already exists, delete it and recreate it + _, err = os.Stat(config.Conf.DBPath) + if err != nil { + Logger.Printf("Database not found... creating db at %s\n", config.Conf.DBPath) + db.CreateTable(User{}) + db.CreateTable(Target{}) + db.CreateTable(Result{}) + db.CreateTable(Group{}) + db.CreateTable(GroupTarget{}) + db.CreateTable(Template{}) + db.CreateTable(UserTemplate{}) + db.CreateTable(Campaign{}) + //Create the default user + init_user := User{ + Username: "admin", + Hash: "$2a$10$IYkPp0.QsM81lYYPrQx6W.U6oQGw7wMpozrKhKAHUBVL4mkm/EvAS", //gophish + ApiKey: "12345678901234567890123456789012", + } + err = db.Save(&init_user).Error + if err != nil { + Logger.Println(err) + } + } + return nil +} diff --git a/models/models_test.go b/models/models_test.go new file mode 100644 index 00000000..1bd8bd6d --- /dev/null +++ b/models/models_test.go @@ -0,0 +1,47 @@ +package models + +import ( + "os" + "testing" + + "github.com/jordan-wright/gophish/config" + "launchpad.net/gocheck" +) + +// Hook up gocheck into the "go test" runner. +func Test(t *testing.T) { gocheck.TestingT(t) } + +type ModelsSuite struct{} + +var _ = gocheck.Suite(&ModelsSuite{}) + +func (s *ModelsSuite) SetUpSuite(c *gocheck.C) { + config.Conf.DBPath = "../gophish_test.db" + err := Setup() + if err != nil { + c.Fatalf("Failed creating database: %v", err) + } +} + +func (s *ModelsSuite) TestGetUser(c *gocheck.C) { + u, err := GetUser(1) + c.Assert(err, gocheck.Equals, nil) + c.Assert(u.Username, gocheck.Equals, "admin") +} + +func (s *ModelsSuite) TestPutUser(c *gocheck.C) { + u, err := GetUser(1) + u.Username = "admin_changed" + err = PutUser(&u) + c.Assert(err, gocheck.Equals, nil) + u, err = GetUser(1) + c.Assert(u.Username, gocheck.Equals, "admin_changed") +} + +func (s *ModelsSuite) TearDownSuite(c *gocheck.C) { + db.DB().Close() + err := os.Remove(config.Conf.DBPath) + if err != nil { + c.Fatalf("Failed deleting test database: %v", err) + } +} diff --git a/models/template.go b/models/template.go index 39f01dfd..a26c1d8e 100644 --- a/models/template.go +++ b/models/template.go @@ -4,39 +4,40 @@ import "time" type Template struct { Id int64 `json:"id"` - Name string `json:"name" db:"name"` - Text string `json:"text" db:"text"` - Html string `json:"html" db:"html"` - ModifiedDate time.Time `json:"modified_date" db:"modified_date"` + Name string `json:"name"` + Text string `json:"text"` + Html string `json:"html"` + ModifiedDate time.Time `json:"modified_date"` +} + +type UserTemplate struct { + UserId int64 `json:"-"` + TemplateId int64 `json:"-"` } // GetTemplates returns the templates owned by the given user. func GetTemplates(uid int64) ([]Template, error) { ts := []Template{} - _, err := Conn.Select(&ts, "SELECT t.id, t.name, t.modified_date, t.text, t.html FROM templates t, user_templates ut, users u WHERE ut.uid=u.id AND ut.tid=t.id AND u.id=?", uid) + err := db.Table("templates t").Select("t.*").Joins("left join user_templates ut ON t.id = ut.template_id").Where("ut.user_id=?", uid).Scan(&ts).Error return ts, err } // GetTemplate returns the template, if it exists, specified by the given id and user_id. func GetTemplate(id int64, uid int64) (Template, error) { t := Template{} - err := Conn.SelectOne(&t, "SELECT t.id, t.name, t.modified_date, t.text, t.html FROM templates t, user_templates ut, users u WHERE ut.uid=u.id AND ut.tid=t.id AND t.id=? AND u.id=?", id, uid) - if err != nil { - return t, err - } + err := db.Table("templates t").Select("t.*").Joins("left join user_templates ut ON t.id = ut.template_id").Where("ut.user_id=? and t.id=?", uid, id).Scan(&t).Error return t, err } // PostTemplate creates a new template in the database. func PostTemplate(t *Template, uid int64) error { // Insert into the DB - err = Conn.Insert(t) + err := db.Save(t).Error if err != nil { - Logger.Println(err) return err } // Now, let's add the user->user_templates->template mapping - _, err = Conn.Exec("INSERT OR IGNORE INTO user_templates VALUES (?,?)", uid, t.Id) + err = db.Exec("INSERT OR IGNORE INTO user_templates VALUES (?,?)", uid, t.Id).Error if err != nil { Logger.Printf("Error adding many-many mapping for template %s\n", t.Name) } diff --git a/models/user.go b/models/user.go index 2f329905..a3915c2c 100644 --- a/models/user.go +++ b/models/user.go @@ -5,16 +5,16 @@ import "database/sql" // User represents the user model for gophish. type User struct { Id int64 `json:"id"` - Username string `json:"username"` + Username string `json:"username" sql:"not null;unique"` Hash string `json:"-"` - APIKey string `json:"api_key" db:"api_key"` + ApiKey string `json:"api_key" sql:"not null;unique"` } // GetUser returns the user that the given id corresponds to. If no user is found, an // error is thrown. func GetUser(id int64) (User, error) { u := User{} - err := Conn.SelectOne(&u, "SELECT * FROM Users WHERE id=?", id) + err := db.Where("id=?", id).First(&u).Error if err != nil { return u, err } @@ -23,9 +23,9 @@ func GetUser(id int64) (User, error) { // GetUserByAPIKey returns the user that the given API Key corresponds to. If no user is found, an // error is thrown. -func GetUserByAPIKey(key []byte) (User, error) { +func GetUserByAPIKey(key string) (User, error) { u := User{} - err := Conn.SelectOne(&u, "SELECT id, username, api_key FROM Users WHERE apikey=?", key) + err := db.Where("api_key = ?", key).First(&u).Error if err != nil { return u, err } @@ -36,7 +36,7 @@ func GetUserByAPIKey(key []byte) (User, error) { // error is thrown. func GetUserByUsername(username string) (User, error) { u := User{} - err := Conn.SelectOne(&u, "SELECT * FROM Users WHERE username=?", username) + err := db.Where("username = ?", username).First(&u).Error if err != sql.ErrNoRows { return u, ErrUsernameTaken } else if err != nil { @@ -47,6 +47,6 @@ func GetUserByUsername(username string) (User, error) { // PutUser updates the given user func PutUser(u *User) error { - _, err := Conn.Update(u) + err := db.Save(u).Error return err } diff --git a/templates/base.html b/templates/base.html index fa1404e6..9f7fc38a 100644 --- a/templates/base.html +++ b/templates/base.html @@ -20,7 +20,7 @@ {{%if .User%}} - + {{%end%}} diff --git a/templates/settings.html b/templates/settings.html index 8cc15477..1c02c380 100644 --- a/templates/settings.html +++ b/templates/settings.html @@ -27,7 +27,7 @@