mirror of https://github.com/gophish/gophish
Updated how targets are added to a group, improving efficiency
parent
eb332880ab
commit
3692fcc680
|
@ -163,45 +163,49 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
||||||
g = models.Group{}
|
g = models.Group{}
|
||||||
|
|
||||||
//Check if content is CSV
|
//Check if content is CSV
|
||||||
var csvmode = false
|
|
||||||
contentType := r.Header.Get("Content-Type")
|
contentType := r.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
csvmode = true
|
targets, _, err := util.ParseCSV(r)
|
||||||
targets, groupname, err := util.ParseCSV(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest)
|
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We need to fetch all the existing targets for this group, so as to not overwrite them below
|
err = models.AddTargetsToGroup(targets, id)
|
||||||
et, _ := models.GetTargets(id)
|
if err != nil {
|
||||||
g.Targets = append(targets, et...)
|
log.Errorf("error add targets to group: %v", err)
|
||||||
g.Name = groupname
|
JSONResponse(w, models.Response{Success: false, Message: "Unable to add targets to group!"}, http.StatusBadRequest)
|
||||||
g.Id = id // ID isn't supplied in the CSV file upload. Perhaps we could use the filename paramter for this? I'm not sure if this is necessary though.
|
return
|
||||||
} else { // else JSON
|
}
|
||||||
|
// With CSV we don't return the entire target list, in line with the new pagination server side processing.
|
||||||
|
ng, err := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64))
|
||||||
|
if err != nil {
|
||||||
|
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
JSONResponse(w, ng, http.StatusCreated)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default JSON
|
||||||
err = json.NewDecoder(r.Body).Decode(&g)
|
err = json.NewDecoder(r.Body).Decode(&g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error decoding group: %v", err)
|
log.Errorf("error decoding group: %v", err)
|
||||||
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
|
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if g.Id != id {
|
if g.Id != id {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: "Error: /:id and group_id mismatch"}, http.StatusInternalServerError)
|
JSONResponse(w, models.Response{Success: false, Message: "Error: /:id and group_id mismatch"}, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
g.ModifiedDate = time.Now().UTC()
|
g.ModifiedDate = time.Now().UTC()
|
||||||
g.UserId = ctx.Get(r, "user_id").(int64)
|
g.UserId = ctx.Get(r, "user_id").(int64)
|
||||||
|
|
||||||
err = models.PutGroup(&g)
|
err = models.PutGroup(&g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest)
|
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// With CSV we don't return the entire target list, in line with the new pagination server side processing. To maintain backwards API capabiltiy the JSON request
|
|
||||||
// will still return the full list.
|
|
||||||
if csvmode == true {
|
|
||||||
JSONResponse(w, models.GroupSummary{Id: g.Id, Name: g.Name, ModifiedDate: g.ModifiedDate, NumTargets: int64(len(g.Targets))}, http.StatusCreated)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
JSONResponse(w, g, http.StatusOK)
|
JSONResponse(w, g, http.StatusOK)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -241,7 +245,7 @@ func (as *Server) GroupTarget(w http.ResponseWriter, r *http.Request) {
|
||||||
switch {
|
switch {
|
||||||
case r.Method == "PUT":
|
case r.Method == "PUT":
|
||||||
// Add an individual target to a group
|
// Add an individual target to a group
|
||||||
err = models.AddTargetToGroup(t, gid)
|
err = models.AddTargetsToGroup([]models.Target{t}, gid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: "Unable to add target to group"}, http.StatusNotFound)
|
JSONResponse(w, models.Response{Success: false, Message: "Unable to add target to group"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
|
|
|
@ -357,25 +357,53 @@ func UpdateGroup(g *Group) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddTargetToGroup adds a single given target to a group by group ID
|
// AddTargetsToGroup adds targets to a group, updating on duplicate email
|
||||||
func AddTargetToGroup(nt Target, gid int64) error {
|
func AddTargetsToGroup(nts []Target, gid int64) error {
|
||||||
// Check if target already exists in group
|
|
||||||
tmpt, err := GetTargetByEmail(gid, nt.Email)
|
// Fetch group's existing targets from database.
|
||||||
|
ets, err := GetTargets(gid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// Load email to target id cache
|
||||||
|
existingTargetCache := make(map[string]int64, len(ets))
|
||||||
|
for _, t := range ets {
|
||||||
|
existingTargetCache[t.Email] = t.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step over each new target and see if it exists in the cache map.
|
||||||
|
tx := db.Begin()
|
||||||
|
for _, nt := range nts {
|
||||||
|
if _, ok := existingTargetCache[nt.Email]; ok {
|
||||||
|
// Update
|
||||||
|
nt.Id = existingTargetCache[nt.Email]
|
||||||
|
err = UpdateTarget(tx, nt)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Otherwise, add target if not in database
|
||||||
|
err = insertTargetIntoGroup(tx, nt, gid)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // for each new target
|
||||||
|
|
||||||
|
err = tx.Model(&Group{}).Where("id=?", gid).Update("ModifiedDate", time.Now().UTC()).Error // put this in the tx too TODO
|
||||||
|
if err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// If target exists in group, update it.
|
err = tx.Commit().Error
|
||||||
if len(tmpt) > 0 {
|
|
||||||
nt.Id = tmpt[0].Id
|
|
||||||
err = UpdateTarget(db, nt)
|
|
||||||
} else {
|
|
||||||
err = insertTargetIntoGroup(db, nt, gid)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
tx.Rollback()
|
||||||
}
|
}
|
||||||
err = db.Model(&Group{}).Where("id=?", gid).Update("ModifiedDate", time.Now().UTC()).Error
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue