diff --git a/controllers/phish.go b/controllers/phish.go index 57677c6b..fd2fefdf 100644 --- a/controllers/phish.go +++ b/controllers/phish.go @@ -36,7 +36,8 @@ type eventDetails struct { // CreatePhishingRouter creates the router that handles phishing connections. func CreatePhishingRouter() http.Handler { router := mux.NewRouter() - router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("./static/endpoint/")))) + fileServer := http.FileServer(UnindexedFileSystem{http.Dir("./static/endpoint/")}) + router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", fileServer)) router.HandleFunc("/track", PhishTracker) router.HandleFunc("/robots.txt", RobotsHandler) router.HandleFunc("/{path:.*}/track", PhishTracker) diff --git a/controllers/route.go b/controllers/route.go index 94bb734c..6ce76a60 100644 --- a/controllers/route.go +++ b/controllers/route.go @@ -60,7 +60,7 @@ func CreateAdminRouter() http.Handler { api.HandleFunc("/import/site", Use(API_Import_Site, mid.RequireAPIKey)) // Setup static file serving - router.PathPrefix("/").Handler(http.FileServer(http.Dir("./static/"))) + router.PathPrefix("/").Handler(http.FileServer(UnindexedFileSystem{http.Dir("./static/")})) // Setup CSRF Protection csrfHandler := csrf.Protect([]byte(auth.GenerateSecureKey()), diff --git a/controllers/static.go b/controllers/static.go new file mode 100644 index 00000000..4cdcbd6d --- /dev/null +++ b/controllers/static.go @@ -0,0 +1,35 @@ +package controllers + +import ( + "net/http" + "strings" +) + +// UnindexedFileSystem is an implementation of a standard http.FileSystem +// without the ability to list files in the directory. +// This implementation is largely inspired by +// https://www.alexedwards.net/blog/disable-http-fileserver-directory-listings +type UnindexedFileSystem struct { + fs http.FileSystem +} + +// Open returns a file from the static directory. If the requested path ends +// with a slash, there is a check for an index.html file. If none exists, then +// an error is returned. +func (ufs UnindexedFileSystem) Open(name string) (http.File, error) { + f, err := ufs.fs.Open(name) + if err != nil { + return nil, err + } + + s, err := f.Stat() + if s.IsDir() { + index := strings.TrimSuffix(name, "/") + "/index.html" + indexFile, err := ufs.fs.Open(index) + if err != nil { + return nil, err + } + return indexFile, nil + } + return f, nil +} diff --git a/controllers/static_test.go b/controllers/static_test.go new file mode 100644 index 00000000..0f36464f --- /dev/null +++ b/controllers/static_test.go @@ -0,0 +1,81 @@ +package controllers + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" +) + +var fileContent = []byte("Hello world") + +func mustRemoveAll(dir string) { + err := os.RemoveAll(dir) + if err != nil { + panic(err) + } +} + +func createTestFile(dir, filename string) error { + return ioutil.WriteFile(filepath.Join(dir, filename), fileContent, 0644) +} + +func (s *ControllersSuite) TestGetStaticFile() { + dir, err := ioutil.TempDir("static/endpoint", "test-") + tempFolder := filepath.Base(dir) + + s.Nil(err) + defer mustRemoveAll(dir) + + err = createTestFile(dir, "foo.txt") + s.Nil(nil, err) + + resp, err := http.Get(fmt.Sprintf("%s/static/%s/foo.txt", ps.URL, tempFolder)) + s.Nil(err) + + defer resp.Body.Close() + got, err := ioutil.ReadAll(resp.Body) + s.Nil(err) + + s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got)) +} + +func (s *ControllersSuite) TestStaticFileListing() { + dir, err := ioutil.TempDir("static/endpoint", "test-") + tempFolder := filepath.Base(dir) + + s.Nil(err) + defer mustRemoveAll(dir) + + err = createTestFile(dir, "foo.txt") + s.Nil(nil, err) + + resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder)) + s.Nil(err) + + defer resp.Body.Close() + s.Nil(err) + s.Equal(resp.StatusCode, http.StatusNotFound) +} + +func (s *ControllersSuite) TestStaticIndex() { + dir, err := ioutil.TempDir("static/endpoint", "test-") + tempFolder := filepath.Base(dir) + + s.Nil(err) + defer mustRemoveAll(dir) + + err = createTestFile(dir, "index.html") + s.Nil(nil, err) + + resp, err := http.Get(fmt.Sprintf("%s/static/%s/", ps.URL, tempFolder)) + s.Nil(err) + + defer resp.Body.Close() + got, err := ioutil.ReadAll(resp.Body) + s.Nil(err) + + s.Equal(bytes.Compare(fileContent, got), 0, fmt.Sprintf("Got %s", got)) +}