diff --git a/controllers/api.go b/controllers/api.go index a2c5e701..b535ff17 100644 --- a/controllers/api.go +++ b/controllers/api.go @@ -176,7 +176,8 @@ func API_Groups(w http.ResponseWriter, r *http.Request) { return } g.ModifiedDate = time.Now() - err = models.PostGroup(&g, ctx.Get(r, "user_id").(int64)) + g.UserId = ctx.Get(r, "user_id").(int64) + err = models.PostGroup(&g) if checkError(err, w, "Error inserting group", http.StatusInternalServerError) { return } @@ -205,11 +206,11 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) { } writeJSON(w, gj) case r.Method == "DELETE": - _, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) + g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64)) if checkError(err, w, "No group found", http.StatusNotFound) { return } - err = models.DeleteGroup(id) + err = models.DeleteGroup(&g) if checkError(err, w, "Error deleting group", http.StatusInternalServerError) { return } @@ -230,7 +231,9 @@ func API_Groups_Id(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error: No targets specified", http.StatusBadRequest) return } - err = models.PutGroup(&g, ctx.Get(r, "user_id").(int64)) + g.ModifiedDate = time.Now() + g.UserId = ctx.Get(r, "user_id").(int64) + err = models.PutGroup(&g) if checkError(err, w, "Error updating group", http.StatusInternalServerError) { return } diff --git a/models/group.go b/models/group.go index 8341dc92..dc80c501 100644 --- a/models/group.go +++ b/models/group.go @@ -9,16 +9,12 @@ import ( type Group struct { Id int64 `json:"id"` + UserId int64 `json:"-"` Name string `json:"name"` ModifiedDate time.Time `json:"modified_date"` Targets []Target `json:"targets" sql:"-"` } -type UserGroup struct { - UserId int64 `json:"-"` - GroupId int64 `json:"-"` -} - type GroupTarget struct { GroupId int64 `json:"-"` TargetId int64 `json:"-"` @@ -32,7 +28,7 @@ type Target struct { // GetGroups returns the groups owned by the given user. func GetGroups(uid int64) ([]Group, error) { gs := []Group{} - err := db.Table("groups g").Select("g.*").Joins("left join user_groups ug ON g.id = ug.group_id").Where("ug.user_id=?", uid).Scan(&gs).Error + err := db.Where("user_id=?", uid).Find(&gs).Error if err != nil { Logger.Println(err) return gs, err @@ -49,7 +45,7 @@ func GetGroups(uid int64) ([]Group, error) { // GetGroup returns the group, if it exists, specified by the given id and user_id. func GetGroup(id int64, uid int64) (Group, error) { g := Group{} - err := db.Table("groups g").Select("g.*").Joins("left join user_groups ug ON g.id = ug.group_id").Where("ug.user_id=? and g.id=?", uid, id).Scan(&g).Error + err := db.Where("user_id=? and id=?", uid, id).Find(&g).Error if err != nil { Logger.Println(err) return g, err @@ -64,7 +60,7 @@ func GetGroup(id int64, uid int64) (Group, error) { // GetGroupByName returns the group, if it exists, specified by the given name and user_id. func GetGroupByName(n string, uid int64) (Group, error) { g := Group{} - err := db.Table("groups g").Select("g.*").Joins("left join user_groups ug ON g.id = ug.group_id").Where("ug.user_id=? and g.name=?", uid, n).Scan(&g).Error + err := db.Where("user_id=? and name=?", uid, n).Find(&g).Error if err != nil { Logger.Println(err) return g, err @@ -77,19 +73,13 @@ func GetGroupByName(n string, uid int64) (Group, error) { } // PostGroup creates a new group in the database. -func PostGroup(g *Group, uid int64) error { +func PostGroup(g *Group) error { // Insert into the DB err = db.Save(g).Error if err != nil { Logger.Println(err) return err } - // Now, let's add the user->user_groups->group mapping - err = db.Save(&UserGroup{GroupId: g.Id, UserId: uid}).Error - if err != nil { - Logger.Println(err) - return err - } for _, t := range g.Targets { insertTargetIntoGroup(t, g.Id) } @@ -97,11 +87,7 @@ func PostGroup(g *Group, uid int64) error { } // PutGroup updates the given group if found in the database. -func PutGroup(g *Group, uid int64) error { - // Update all the foreign keys, and many to many relationships - // We will only delete the group->targets entries. We keep the actual targets - // since they are needed by the Results table - // Get all the targets currently in the database for the group +func PutGroup(g *Group) error { ts := []Target{} ts, err = GetTargets(g.Id) if err != nil { @@ -144,8 +130,6 @@ func PutGroup(g *Group, uid int64) error { insertTargetIntoGroup(nt, g.Id) } } - // Update the group - g.ModifiedDate = time.Now() err = db.Save(g).Error /*_, err = Conn.Update(g)*/ if err != nil { @@ -155,6 +139,23 @@ func PutGroup(g *Group, uid int64) error { return nil } +// DeleteGroup deletes a given group by group ID and user ID +func DeleteGroup(g *Group) error { + // Delete all the group_targets entries for this group + err := db.Where("group_id=?", g.Id).Delete(&GroupTarget{}).Error + if err != nil { + Logger.Println(err) + return err + } + // Delete the group itself + err = db.Delete(g).Error + if err != nil { + Logger.Println(err) + return err + } + return err +} + func insertTargetIntoGroup(t Target, gid int64) error { if _, err = mail.ParseAddress(t.Email); err != nil { Logger.Printf("Invalid email %s\n", t.Email) @@ -162,7 +163,6 @@ func insertTargetIntoGroup(t Target, gid int64) error { } trans := db.Begin() trans.Where(t).FirstOrCreate(&t) - Logger.Printf("ID of Target after FirstOrCreate: %d", t.Id) if err != nil { Logger.Printf("Error adding target: %s\n", t.Email) return err @@ -187,29 +187,6 @@ func insertTargetIntoGroup(t Target, gid int64) error { return nil } -// DeleteGroup deletes a given group by group ID and user ID -func DeleteGroup(id int64) error { - // Delete all the group_targets entries for this group - err := db.Where("group_id=?", id).Delete(&GroupTarget{}).Error - if err != nil { - Logger.Println(err) - return err - } - // Delete the reference to the group in the user_group table - err = db.Where("group_id=?", id).Delete(&UserGroup{}).Error - if err != nil { - Logger.Println(err) - return err - } - // Delete the group itself - err = db.Delete(&Group{Id: id}).Error - if err != nil { - Logger.Println(err) - return err - } - return err -} - func GetTargets(gid int64) ([]Target, error) { ts := []Target{} err := db.Table("targets t").Select("t.id, t.email").Joins("left join group_targets gt ON t.id = gt.target_id").Where("gt.group_id=?", gid).Scan(&ts).Error diff --git a/models/models.go b/models/models.go index ecf572a3..87cbc372 100644 --- a/models/models.go +++ b/models/models.go @@ -42,8 +42,8 @@ func Setup() error { db.CreateTable(Result{}) db.CreateTable(Group{}) db.CreateTable(GroupTarget{}) - db.CreateTable(UserGroup{}) db.CreateTable(Template{}) + db.CreateTable(UserTemplate{}) db.CreateTable(Campaign{}) //Create the default user init_user := User{ diff --git a/models/models_test.go b/models/models_test.go index 6c2d6733..1bd8bd6d 100644 --- a/models/models_test.go +++ b/models/models_test.go @@ -29,6 +29,15 @@ func (s *ModelsSuite) TestGetUser(c *gocheck.C) { c.Assert(u.Username, gocheck.Equals, "admin") } +func (s *ModelsSuite) TestPutUser(c *gocheck.C) { + u, err := GetUser(1) + u.Username = "admin_changed" + err = PutUser(&u) + c.Assert(err, gocheck.Equals, nil) + u, err = GetUser(1) + c.Assert(u.Username, gocheck.Equals, "admin_changed") +} + func (s *ModelsSuite) TearDownSuite(c *gocheck.C) { db.DB().Close() err := os.Remove(config.Conf.DBPath) diff --git a/models/template.go b/models/template.go index eb97476f..a26c1d8e 100644 --- a/models/template.go +++ b/models/template.go @@ -10,33 +10,34 @@ type Template struct { ModifiedDate time.Time `json:"modified_date"` } +type UserTemplate struct { + UserId int64 `json:"-"` + TemplateId int64 `json:"-"` +} + // GetTemplates returns the templates owned by the given user. func GetTemplates(uid int64) ([]Template, error) { ts := []Template{} - _, err := Conn.Select(&ts, "SELECT t.id, t.name, t.modified_date, t.text, t.html FROM templates t, user_templates ut, users u WHERE ut.uid=u.id AND ut.tid=t.id AND u.id=?", uid) + err := db.Table("templates t").Select("t.*").Joins("left join user_templates ut ON t.id = ut.template_id").Where("ut.user_id=?", uid).Scan(&ts).Error return ts, err } // GetTemplate returns the template, if it exists, specified by the given id and user_id. func GetTemplate(id int64, uid int64) (Template, error) { t := Template{} - err := Conn.SelectOne(&t, "SELECT t.id, t.name, t.modified_date, t.text, t.html FROM templates t, user_templates ut, users u WHERE ut.uid=u.id AND ut.tid=t.id AND t.id=? AND u.id=?", id, uid) - if err != nil { - return t, err - } + err := db.Table("templates t").Select("t.*").Joins("left join user_templates ut ON t.id = ut.template_id").Where("ut.user_id=? and t.id=?", uid, id).Scan(&t).Error return t, err } // PostTemplate creates a new template in the database. func PostTemplate(t *Template, uid int64) error { // Insert into the DB - err = Conn.Insert(t) + err := db.Save(t).Error if err != nil { - Logger.Println(err) return err } // Now, let's add the user->user_templates->template mapping - _, err = Conn.Exec("INSERT OR IGNORE INTO user_templates VALUES (?,?)", uid, t.Id) + err = db.Exec("INSERT OR IGNORE INTO user_templates VALUES (?,?)", uid, t.Id).Error if err != nil { Logger.Printf("Error adding many-many mapping for template %s\n", t.Name) } diff --git a/models/user.go b/models/user.go index 9f16a1a6..a3915c2c 100644 --- a/models/user.go +++ b/models/user.go @@ -47,6 +47,6 @@ func GetUserByUsername(username string) (User, error) { // PutUser updates the given user func PutUser(u *User) error { - err := db.Update(&u).Error + err := db.Save(u).Error return err }