diff --git a/controllers/api/group.go b/controllers/api/group.go index c4bdfbeb..94d1a7b4 100644 --- a/controllers/api/group.go +++ b/controllers/api/group.go @@ -163,45 +163,49 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) { g = models.Group{} //Check if content is CSV - var csvmode = false contentType := r.Header.Get("Content-Type") if strings.HasPrefix(contentType, "multipart/form-data") { - csvmode = true - targets, groupname, err := util.ParseCSV(r) + targets, _, err := util.ParseCSV(r) if err != nil { JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest) return } - // We need to fetch all the existing targets for this group, so as to not overwrite them below - et, _ := models.GetTargets(id) - g.Targets = append(targets, et...) - g.Name = groupname - g.Id = id // ID isn't supplied in the CSV file upload. Perhaps we could use the filename paramter for this? I'm not sure if this is necessary though. - } else { // else JSON - err = json.NewDecoder(r.Body).Decode(&g) + err = models.AddTargetsToGroup(targets, id) if err != nil { - log.Errorf("error decoding group: %v", err) - JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError) + log.Errorf("error add targets to group: %v", err) + JSONResponse(w, models.Response{Success: false, Message: "Unable to add targets to group!"}, http.StatusBadRequest) return } + // With CSV we don't return the entire target list, in line with the new pagination server side processing. + ng, err := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64)) + if err != nil { + JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound) + return + } + JSONResponse(w, ng, http.StatusCreated) + return } + + // Default JSON + err = json.NewDecoder(r.Body).Decode(&g) + if err != nil { + log.Errorf("error decoding group: %v", err) + JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError) + return + } + if g.Id != id { JSONResponse(w, models.Response{Success: false, Message: "Error: /:id and group_id mismatch"}, http.StatusInternalServerError) return } g.ModifiedDate = time.Now().UTC() g.UserId = ctx.Get(r, "user_id").(int64) + err = models.PutGroup(&g) if err != nil { JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest) return } - // With CSV we don't return the entire target list, in line with the new pagination server side processing. To maintain backwards API capabiltiy the JSON request - // will still return the full list. - if csvmode == true { - JSONResponse(w, models.GroupSummary{Id: g.Id, Name: g.Name, ModifiedDate: g.ModifiedDate, NumTargets: int64(len(g.Targets))}, http.StatusCreated) - return - } JSONResponse(w, g, http.StatusOK) } } @@ -241,7 +245,7 @@ func (as *Server) GroupTarget(w http.ResponseWriter, r *http.Request) { switch { case r.Method == "PUT": // Add an individual target to a group - err = models.AddTargetToGroup(t, gid) + err = models.AddTargetsToGroup([]models.Target{t}, gid) if err != nil { JSONResponse(w, models.Response{Success: false, Message: "Unable to add target to group"}, http.StatusNotFound) return diff --git a/models/group.go b/models/group.go index 27525953..7860ccb6 100644 --- a/models/group.go +++ b/models/group.go @@ -357,25 +357,53 @@ func UpdateGroup(g *Group) error { return err } -// AddTargetToGroup adds a single given target to a group by group ID -func AddTargetToGroup(nt Target, gid int64) error { - // Check if target already exists in group - tmpt, err := GetTargetByEmail(gid, nt.Email) +// AddTargetsToGroup adds targets to a group, updating on duplicate email +func AddTargetsToGroup(nts []Target, gid int64) error { + + // Fetch group's existing targets from database. + ets, err := GetTargets(gid) if err != nil { return err } + // Load email to target id cache + existingTargetCache := make(map[string]int64, len(ets)) + for _, t := range ets { + existingTargetCache[t.Email] = t.Id + } + + // Step over each new target and see if it exists in the cache map. + tx := db.Begin() + for _, nt := range nts { + if _, ok := existingTargetCache[nt.Email]; ok { + // Update + nt.Id = existingTargetCache[nt.Email] + err = UpdateTarget(tx, nt) + if err != nil { + log.Error(err) + tx.Rollback() + return err + } + } else { + // Otherwise, add target if not in database + err = insertTargetIntoGroup(tx, nt, gid) + if err != nil { + log.Error(err) + tx.Rollback() + return err + } + } + } // for each new target + + err = tx.Model(&Group{}).Where("id=?", gid).Update("ModifiedDate", time.Now().UTC()).Error // put this in the tx too TODO + if err != nil { + tx.Rollback() + return err + } - // If target exists in group, update it. - if len(tmpt) > 0 { - nt.Id = tmpt[0].Id - err = UpdateTarget(db, nt) - } else { - err = insertTargetIntoGroup(db, nt, gid) - } + err = tx.Commit().Error if err != nil { - return err + tx.Rollback() } - err = db.Model(&Group{}).Where("id=?", gid).Update("ModifiedDate", time.Now().UTC()).Error return err }