mirror of https://github.com/gophish/gophish
Reverted models.GetGroup to original and created new GetDataTable function to handle pagination requests. Neater, and allowed for retrieval of filter count when searching
parent
81c979886d
commit
b9374aaffe
|
@ -99,12 +99,10 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
||||||
// Paramters passed by DataTables for pagination are handled below
|
// Paramters passed by DataTables for pagination are handled below
|
||||||
v := r.URL.Query()
|
v := r.URL.Query()
|
||||||
search := v.Get("search[value]")
|
search := v.Get("search[value]")
|
||||||
|
|
||||||
sortcolumn := v.Get("order[0][column]")
|
sortcolumn := v.Get("order[0][column]")
|
||||||
sortdir := v.Get("order[0][dir]")
|
sortdir := v.Get("order[0][dir]")
|
||||||
sortby := v.Get("columns[" + sortcolumn + "][data]")
|
sortby := v.Get("columns[" + sortcolumn + "][data]")
|
||||||
order := sortby + " " + sortdir // e.g "first_name asc"
|
order := sortby + " " + sortdir // e.g "first_name asc"
|
||||||
|
|
||||||
start, err := strconv.ParseInt(v.Get("start"), 0, 64)
|
start, err := strconv.ParseInt(v.Get("start"), 0, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
start = -1 // Default. gorm will ignore with this value.
|
start = -1 // Default. gorm will ignore with this value.
|
||||||
|
@ -118,11 +116,23 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
||||||
draw = -1 // If the draw value is missing we can assume this is not a DataTable request and return regular API result
|
draw = -1 // If the draw value is missing we can assume this is not a DataTable request and return regular API result
|
||||||
}
|
}
|
||||||
|
|
||||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64), start, length, search, order)
|
var g models.Group
|
||||||
if err != nil {
|
if draw == -1 {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
g, err = models.GetGroup(id, ctx.Get(r, "user_id").(int64))
|
||||||
return
|
if err != nil {
|
||||||
|
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// We don't want to fetch the whole set of targets from a group if we're handling a pagination request. This call
|
||||||
|
// is just to validate group ownership
|
||||||
|
_, err := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64))
|
||||||
|
if err != nil {
|
||||||
|
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case r.Method == "GET":
|
case r.Method == "GET":
|
||||||
|
|
||||||
|
@ -131,12 +141,13 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
||||||
JSONResponse(w, g, http.StatusOK)
|
JSONResponse(w, g, http.StatusOK)
|
||||||
} else {
|
} else {
|
||||||
// Handle pagination for DataTable
|
// Handle pagination for DataTable
|
||||||
gs, _ := models.GetGroupSummary(id, ctx.Get(r, "user_id").(int64)) // We need to get the total number of records of the group
|
dT, err := models.GetDataTable(id, start, length, search, order)
|
||||||
dT := models.DataTable{Draw: draw, RecordsTotal: gs.NumTargets, RecordsFiltered: int64(len(g.Targets))}
|
if err != nil {
|
||||||
dT.Data = make([]interface{}, len(g.Targets)) // Pseudocode of 'dT.Data = g.Targets'. https://golang.org/doc/faq#convert_slice_of_interface
|
log.Errorf("error fetching datatable: %v", err)
|
||||||
for i, v := range g.Targets {
|
JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusInternalServerError)
|
||||||
dT.Data[i] = v
|
return
|
||||||
}
|
}
|
||||||
|
dT.Draw = draw
|
||||||
JSONResponse(w, dT, http.StatusOK)
|
JSONResponse(w, dT, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +173,7 @@ func (as *Server) Group(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We need to fetch all the existing targets for this group, so as to not overwrite them below
|
// We need to fetch all the existing targets for this group, so as to not overwrite them below
|
||||||
et, _ := models.GetTargets(id, -1, -1, "", "")
|
et, _ := models.GetTargets(id)
|
||||||
g.Targets = append(targets, et...)
|
g.Targets = append(targets, et...)
|
||||||
g.Name = groupname
|
g.Name = groupname
|
||||||
g.Id = id // ID isn't supplied in the CSV file upload. Perhaps we could use the filename paramter for this? I'm not sure if this is necessary though.
|
g.Id = id // ID isn't supplied in the CSV file upload. Perhaps we could use the filename paramter for this? I'm not sure if this is necessary though.
|
||||||
|
@ -216,7 +227,7 @@ func (as *Server) GroupTarget(w http.ResponseWriter, r *http.Request) {
|
||||||
gid, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
gid, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
||||||
|
|
||||||
// Ensure the group belongs to the user
|
// Ensure the group belongs to the user
|
||||||
_, err := models.GetGroup(gid, ctx.Get(r, "user_id").(int64), 0, 0, "", "")
|
_, err := models.GetGroupSummary(gid, ctx.Get(r, "user_id").(int64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
|
@ -252,7 +263,7 @@ func (as *Server) GroupRename(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
id, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
id, _ := strconv.ParseInt(vars["id"], 0, 64) // group id
|
||||||
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64), 0, 0, "", "")
|
g, err := models.GetGroup(id, ctx.Get(r, "user_id").(int64))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
JSONResponse(w, models.Response{Success: false, Message: "Group not found"}, http.StatusNotFound)
|
||||||
return
|
return
|
||||||
|
|
|
@ -124,7 +124,7 @@ func GetGroups(uid int64) ([]Group, error) {
|
||||||
return gs, err
|
return gs, err
|
||||||
}
|
}
|
||||||
for i := range gs {
|
for i := range gs {
|
||||||
gs[i].Targets, err = GetTargets(gs[i].Id, -1, -1, "", "")
|
gs[i].Targets, err = GetTargets(gs[i].Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -154,15 +154,14 @@ func GetGroupSummaries(uid int64) (GroupSummaries, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroup returns the group, if it exists, specified by the given id and user_id.
|
// GetGroup returns the group, if it exists, specified by the given id and user_id.
|
||||||
// Filter on number of results and starting point with 'start' and 'length' for pagination.
|
func GetGroup(id int64, uid int64) (Group, error) {
|
||||||
func GetGroup(id int64, uid int64, start int64, length int64, search string, order string) (Group, error) {
|
|
||||||
g := Group{}
|
g := Group{}
|
||||||
err := db.Where("user_id=? and id=?", uid, id).Find(&g).Error
|
err := db.Where("user_id=? and id=?", uid, id).Find(&g).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return g, err
|
return g, err
|
||||||
}
|
}
|
||||||
g.Targets, err = GetTargets(g.Id, start, length, search, order)
|
g.Targets, err = GetTargets(g.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -194,7 +193,7 @@ func GetGroupByName(n string, uid int64) (Group, error) {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return g, err
|
return g, err
|
||||||
}
|
}
|
||||||
g.Targets, err = GetTargets(g.Id, -1, -1, "", "")
|
g.Targets, err = GetTargets(g.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -237,7 +236,7 @@ func PutGroup(g *Group) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Fetch group's existing targets from database.
|
// Fetch group's existing targets from database.
|
||||||
ts, err := GetTargets(g.Id, -1, -1, "", "")
|
ts, err := GetTargets(g.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(logrus.Fields{
|
log.WithFields(logrus.Fields{
|
||||||
"group_id": g.Id,
|
"group_id": g.Id,
|
||||||
|
@ -425,12 +424,18 @@ func UpdateTarget(tx *gorm.DB, target Target) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTargets performs a many-to-many select to get all the Targets for a Group
|
// GetTargets performs a many-to-many select to get all the Targets for a Group
|
||||||
// Start, length, and search can be supplied, or -1, -1, "" to ignore
|
func GetTargets(gid int64) ([]Target, error) {
|
||||||
func GetTargets(gid int64, start int64, length int64, search string, order string) ([]Target, error) {
|
|
||||||
|
|
||||||
ts := []Target{}
|
ts := []Target{}
|
||||||
var err error
|
err := db.Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Scan(&ts).Error
|
||||||
|
return ts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDataTable performs a many-to-many select to get all the Targets for a Group with supplied filters
|
||||||
|
// start, length, and search, order can be supplied, or -1, -1, "", "" to ignore
|
||||||
|
func GetDataTable(gid int64, start int64, length int64, search string, order string) (DataTable, error) {
|
||||||
|
|
||||||
|
dt := DataTable{}
|
||||||
|
ts := []Target{}
|
||||||
order = strings.TrimSpace(order)
|
order = strings.TrimSpace(order)
|
||||||
search = strings.TrimSpace(search)
|
search = strings.TrimSpace(search)
|
||||||
if order == "" {
|
if order == "" {
|
||||||
|
@ -439,16 +444,35 @@ func GetTargets(gid int64, start int64, length int64, search string, order strin
|
||||||
order = "targets." + order
|
order = "targets." + order
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 1. Get the total number of targets in group:
|
||||||
|
err := db.Table("group_targets").Where("group_id=?", gid).Count(&dt.RecordsTotal).Error
|
||||||
|
if err != nil {
|
||||||
|
return dt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Fetch targets, applying relevant start, length, search, and order paramters.
|
||||||
// TODO: Rather than having two queries create a partial query and include the search options. Haven't been able to figure out how yet.
|
// TODO: Rather than having two queries create a partial query and include the search options. Haven't been able to figure out how yet.
|
||||||
if search != "" {
|
if search != "" {
|
||||||
|
var count int64
|
||||||
search = "%" + search + "%"
|
search = "%" + search + "%"
|
||||||
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Where("targets.first_name LIKE ? OR targets.last_name LIKE ? OR targets.email LIKE ? or targets.position LIKE ?", search, search, search, search).Offset(start).Limit(length).Scan(&ts).Error
|
|
||||||
|
// 2.1 Apply search filter
|
||||||
|
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Where("targets.first_name LIKE ? OR targets.last_name LIKE ? OR targets.email LIKE ? or targets.position LIKE ?", search, search, search, search).Count(&count).Offset(start).Limit(length).Scan(&ts).Error
|
||||||
|
|
||||||
|
dt.RecordsFiltered = count // The number of results from applying the search filter (calculated before trimming down the results with offset and limit)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Offset(start).Limit(length).Scan(&ts).Error
|
err = db.Order(order).Table("targets").Select("targets.id, targets.email, targets.first_name, targets.last_name, targets.position").Joins("left join group_targets gt ON targets.id = gt.target_id").Where("gt.group_id=?", gid).Offset(start).Limit(length).Scan(&ts).Error
|
||||||
|
dt.RecordsFiltered = dt.RecordsTotal
|
||||||
}
|
}
|
||||||
|
|
||||||
return ts, err
|
// 3. Insert targes into datatable struct
|
||||||
|
dt.Data = make([]interface{}, len(ts)) // Pseudocode of 'dT.Data = g.Targets'. https://golang.org/doc/faq#convert_slice_of_interface
|
||||||
|
for i, v := range ts {
|
||||||
|
dt.Data[i] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return dt, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTargetByEmail gets a single target from a group by email address and group id
|
// GetTargetByEmail gets a single target from a group by email address and group id
|
||||||
|
|
Loading…
Reference in New Issue