mirror of https://github.com/gophish/gophish
Reverted models.GetGroup to original and created new GetDataTable function to handle pagination requests. Neater, and allowed for retrieval of filter count when searching
parent
81c979886d
commit
b9374aaffe
|
@ -99,12 +99,10 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
|||
// Paramters passed by DataTables for pagination are handled below
|
||||
v := r.URL.Query()
|
||||
search := v.Get("search[value]")
|
||||
|
||||
sortcolumn := v.Get("order[0][column]")
|
||||
sortdir := v.Get("order[0][dir]")
|
||||
sortby := v.Get("columns[" + sortcolumn + "][data]")
|
||||
order := sortby + " " + sortdir // e.g "first_name asc"
|
||||
|
||||
start, err := strconv.ParseInt(v.Get("start"), 0, 64)
|
||||
if err != nil {
|
||||
start = -1 // Default. gorm will ignore with this value.
|
||||
|
@ -118,11 +116,23 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
|||
draw = -1 // If the draw value is missing we can assume this is not a DataTable request and return regular API result
|
||||
}
|
||||
|
||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64), start, length, search, order)
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||
return
|
||||
var g models.Group
|
||||
if draw == -1 {
|
||||
g, err = models.GetGroup(id, ctx.Get(r, "user_id").(int64))
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// We don't want to fetch the whole set of targets from a group if we're handling a pagination request. This call
|
||||
// is just to validate group ownership
|
||||
_, err := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64))
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case r.Method == "GET":
|
||||
|
||||
|
@ -131,12 +141,13 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
|||
JSONResponse(w, g, http.StatusOK)
|
||||
} else {
|
||||
// Handle pagination for DataTable
|
||||
gs, _ := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64)) // We need to get the total number of records of the group
|
||||
dT := models.DataTable{Draw: draw, RecordsTotal: gs.NumTargets, RecordsFiltered: int64(len(g.Targets))}
|
||||
dT.Data = make([]interface{}, len(g.Targets)) // Pseudocode of 'dT.Data = g.Targets'. https://golang.org/doc/faq#convert_slice_of_interface
|
||||
for i, v := range g.Targets {
|
||||
dT.Data[i] = v
|
||||
dT, err := models.GetDataTable(id, start, length, search, order)
|
||||
if err != nil {
|
||||
log.Errorf("error fetching datatable: %v", err)
|
||||
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
dT.Draw = draw
|
||||
JSONResponse(w, dT, http.StatusOK)
|
||||
}
|
||||
|
||||
|
@ -162,7 +173,7 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// We need to fetch all the existing targets for this group, so as to not overwrite them below
|
||||
et, _ := models.GetTargets(id, -1, -1, "", "")
|
||||
et, _ := models.GetTargets(id)
|
||||
g.Targets = append(targets, et...)
|
||||
g.Name = groupname
|
||||
g.Id = id // ID isn't supplied in the CSV file upload. Perhaps we could use the filename paramter for this? I'm not sure if this is necessary though.
|
||||
|
@ -216,7 +227,7 @@ func (as *Server) GroupTarget(w http.ResponseWriter, r *http.Request) {
|
|||
gid, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
||||
|
||||
// Ensure the group belongs to the user
|
||||
_, err := models.GetGroup(gid, ctx.Get(r, "user_id").(int64), 0, 0, "", "")
|
||||
_, err := models.GetGroupSummary(gid, ctx.Get(r, "user_id").(int64))
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||
return
|
||||
|
@ -252,7 +263,7 @@ func (as *Server) GroupRename(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
vars := mux.Vars(r)
|
||||
id, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64), 0, 0, "", "")
|
||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64))
|
||||
if err != nil {
|
||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||
return
|
||||
|
|
|
@ -124,7 +124,7 @@ func GetGroups(uid int64) ([]Group, error) {
|
|||
return gs, err
|
||||
}
|
||||
for i := range gs {
|
||||
gs[i].Targets, err = GetTargets(gs[i].Id, -1, -1, "", "")
|
||||
gs[i].Targets, err = GetTargets(gs[i].Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
@ -154,15 +154,14 @@ func GetGroupSummaries(uid int64) (GroupSummaries, error) {
|
|||
}
|
||||
|
||||
// GetGroup returns the group, if it exists, specified by the given id and user_id.
|
||||
// Filter on number of results and starting point with 'start' and 'length' for pagination.
|
||||
func GetGroup(id int64, uid int64, start int64, length int64, search string, order string) (Group, error) {
|
||||
func GetGroup(id int64, uid int64) (Group, error) {
|
||||
g := Group{}
|
||||
err := db.Where("user_id=? and id=?", uid, id).Find(&g).Error
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return g, err
|
||||
}
|
||||
g.Targets, err = GetTargets(g.Id, start, length, search, order)
|
||||
g.Targets, err = GetTargets(g.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
@ -194,7 +193,7 @@ func GetGroupByName(n string, uid int64) (Group, error) {
|
|||
log.Error(err)
|
||||
return g, err
|
||||
}
|
||||
g.Targets, err = GetTargets(g.Id, -1, -1, "", "")
|
||||
g.Targets, err = GetTargets(g.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
@ -237,7 +236,7 @@ func PutGroup(g *Group) error {
|
|||
return err
|
||||
}
|
||||
// Fetch group's existing targets from database.
|
||||
ts, err := GetTargets(g.Id, -1, -1, "", "")
|
||||
ts, err := GetTargets(g.Id)
|
||||
if err != nil {
|
||||
log.WithFields(logrus.Fields{
|
||||
"group_id": g.Id,
|
||||
|
@ -425,12 +424,18 @@ func UpdateTarget(tx *gorm.DB, target Target) error {
|
|||
}
|
||||
|
||||
// GetTargets performs a many-to-many select to get all the Targets for a Group
|
||||
// Start, length, and search can be supplied, or -1, -1, "" to ignore
|
||||
func GetTargets(gid int64, start int64, length int64, search string, order string) ([]Target, error) {
|
||||
|
||||
func GetTargets(gid int64) ([]Target, error) {
|
||||
ts := []Target{}
|
||||
var err error
|
||||
err := db.Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Scan(&ts).Error
|
||||
return ts, err
|
||||
}
|
||||
|
||||
// GetDataTable performs a many-to-many select to get all the Targets for a Group with supplied filters
|
||||
// start, length, and search, order can be supplied, or -1, -1, "", "" to ignore
|
||||
func GetDataTable(gid int64, start int64, length int64, search string, order string) (DataTable, error) {
|
||||
|
||||
dt := DataTable{}
|
||||
ts := []Target{}
|
||||
order = strings.TrimSpace(order)
|
||||
search = strings.TrimSpace(search)
|
||||
if order == "" {
|
||||
|
@ -439,16 +444,35 @@ func GetTargets(gid int64, start int64, length int64, search string, order strin
|
|||
order = "targets." + order
|
||||
}
|
||||
|
||||
// 1. Get the total number of targets in group:
|
||||
err := db.Table("group_targets").Where("group_id=?", gid).Count(&dt.RecordsTotal).Error
|
||||
if err != nil {
|
||||
return dt, err
|
||||
}
|
||||
|
||||
// 2. Fetch targets, applying relevant start, length, search, and order paramters.
|
||||
// TODO: Rather than having two queries create a partial query and include the search options. Haven't been able to figure out how yet.
|
||||
if search != "" {
|
||||
var count int64
|
||||
search = "%" + search + "%"
|
||||
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Where("targets.first_name LIKE ? OR targets.last_name LIKE ? OR targets.email LIKE ? or targets.position LIKE ?", search, search, search, search).Offset(start).Limit(length).Scan(&ts).Error
|
||||
|
||||
// 2.1 Apply search filter
|
||||
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Where("targets.first_name LIKE ? OR targets.last_name LIKE ? OR targets.email LIKE ? or targets.position LIKE ?", search, search, search, search).Count(&count).Offset(start).Limit(length).Scan(&ts).Error
|
||||
|
||||
dt.RecordsFiltered = count // The number of results from applying the search filter (calculated before trimming down the results with offset and limit)
|
||||
|
||||
} else {
|
||||
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Offset(start).Limit(length).Scan(&ts).Error
|
||||
dt.RecordsFiltered = dt.RecordsTotal
|
||||
}
|
||||
|
||||
return ts, err
|
||||
// 3. Insert targes into datatable struct
|
||||
dt.Data = make([]interface{}, len(ts)) // Pseudocode of 'dT.Data = g.Targets'. https://golang.org/doc/faq#convert_slice_of_interface
|
||||
for i, v := range ts {
|
||||
dt.Data[i] = v
|
||||
}
|
||||
|
||||
return dt, err
|
||||
}
|
||||
|
||||
// GetTargetByEmail gets a single target from a group by email address and group id
|
||||
|
|
Loading…
Reference in New Issue