From e3352f481e94054ffe08494c9225d3878347b005 Mon Sep 17 00:00:00 2001 From: Jordan Wright Date: Thu, 20 Aug 2020 09:36:18 -0500 Subject: [PATCH] Implement SSRF Mitigations (#1940) Initial commit of SSRF mitigations. This fixes #1908 by creating a *net.Dialer which restricts outbound connections to only allowed IP ranges. This implementation is based on the blog post at https://www.agwa.name/blog/post/preventing_server_side_request_forgery_in_golang To keep things backwards compatible, by default we'll only block connections to 169.254.169.254, the link-local IP address commonly used in cloud environments to retrieve metadata about the running instance. For other internal addresses (e.g. localhost or RFC 1918 addresses), it's assumed that those are available to Gophish. To support more secure environments, we introduce the `allowed_internal_hosts` configuration option where an admin can set one or more IP ranges in CIDR format. If addresses are specified here, then all internal connections will be blocked except to these hosts. There are various bits about this approach I don't really like. For example, since various packages all need this functionality, I had to make the RestrictedDialer a global singleton rather than a dependency off of, say, the admin server. Additionally, since webhooks are implemented via a singleton, I had to introduce a new function, `SetTransport`. Finally, I had to make an update in the gomail package to support a custom net.Dialer. --- config/config.go | 11 +-- controllers/api/import.go | 3 + controllers/api/import_test.go | 84 ++++++++++++++++++ dialer/dialer.go | 158 +++++++++++++++++++++++++++++++++ dialer/dialer_test.go | 85 ++++++++++++++++++ go.mod | 4 +- go.sum | 6 +- gophish.go | 10 +++ imap/imap.go | 6 +- models/smtp.go | 4 +- models/smtp_test.go | 12 +++ webhook/webhook.go | 5 ++ 12 files changed, 373 insertions(+), 15 deletions(-) create mode 100644 controllers/api/import_test.go create mode 100644 dialer/dialer.go create mode 100644 dialer/dialer_test.go diff --git a/config/config.go b/config/config.go index 4e0451de..98ec6b6a 100644 --- a/config/config.go +++ b/config/config.go @@ -9,11 +9,12 @@ import ( // AdminServer represents the Admin server configuration details type AdminServer struct { - ListenURL string `json:"listen_url"` - UseTLS bool `json:"use_tls"` - CertPath string `json:"cert_path"` - KeyPath string `json:"key_path"` - CSRFKey string `json:"csrf_key"` + ListenURL string `json:"listen_url"` + UseTLS bool `json:"use_tls"` + CertPath string `json:"cert_path"` + KeyPath string `json:"key_path"` + CSRFKey string `json:"csrf_key"` + AllowedInternalHosts []string `json:"allowed_internal_hosts"` } // PhishServer represents the Phish server configuration details diff --git a/controllers/api/import.go b/controllers/api/import.go index 7cf96a01..efaf0178 100644 --- a/controllers/api/import.go +++ b/controllers/api/import.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/PuerkitoBio/goquery" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" "github.com/gophish/gophish/util" @@ -113,7 +114,9 @@ func (as *Server) ImportSite(w http.ResponseWriter, r *http.Request) { JSONResponse(w, models.Response{Success: false, Message: err.Error()}, http.StatusBadRequest) return } + restrictedDialer := dialer.Dialer() tr := &http.Transport{ + DialContext: restrictedDialer.DialContext, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, diff --git a/controllers/api/import_test.go b/controllers/api/import_test.go new file mode 100644 index 00000000..2278de50 --- /dev/null +++ b/controllers/api/import_test.go @@ -0,0 +1,84 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gophish/gophish/dialer" + "github.com/gophish/gophish/models" +) + +func makeImportRequest(ctx *testContext, allowedHosts []string, url string) *httptest.ResponseRecorder { + orig := dialer.DefaultDialer.AllowedHosts() + dialer.SetAllowedHosts(allowedHosts) + req := httptest.NewRequest(http.MethodPost, "/api/import/site", + bytes.NewBuffer([]byte(fmt.Sprintf(` + { + "url" : "%s" + } + `, url)))) + req.Header.Set("Content-Type", "application/json") + response := httptest.NewRecorder() + ctx.apiServer.ImportSite(response, req) + dialer.SetAllowedHosts(orig) + return response +} + +func TestDefaultDeniedImport(t *testing.T) { + ctx := setupTest(t) + metadataURL := "http://169.254.169.254/latest/meta-data/" + response := makeImportRequest(ctx, []string{}, metadataURL) + expectedCode := http.StatusBadRequest + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } + got := &models.Response{} + err := json.NewDecoder(response.Body).Decode(got) + if err != nil { + t.Fatalf("error decoding body: %v", err) + } + if !strings.Contains(got.Message, "upstream connection denied") { + t.Fatalf("incorrect response error provided: %s", got.Message) + } +} + +func TestDefaultAllowedImport(t *testing.T) { + ctx := setupTest(t) + h := "" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, h) + })) + defer ts.Close() + response := makeImportRequest(ctx, []string{}, ts.URL) + expectedCode := http.StatusOK + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } +} + +func TestCustomDeniedImport(t *testing.T) { + ctx := setupTest(t) + h := "" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, h) + })) + defer ts.Close() + response := makeImportRequest(ctx, []string{"192.168.1.1"}, ts.URL) + expectedCode := http.StatusBadRequest + if response.Code != expectedCode { + t.Fatalf("incorrect status code received. expected %d got %d", expectedCode, response.Code) + } + got := &models.Response{} + err := json.NewDecoder(response.Body).Decode(got) + if err != nil { + t.Fatalf("error decoding body: %v", err) + } + if !strings.Contains(got.Message, "upstream connection denied") { + t.Fatalf("incorrect response error provided: %s", got.Message) + } +} diff --git a/dialer/dialer.go b/dialer/dialer.go new file mode 100644 index 00000000..15fb402e --- /dev/null +++ b/dialer/dialer.go @@ -0,0 +1,158 @@ +package dialer + +import ( + "fmt" + "net" + "syscall" + "time" +) + +// RestrictedDialer is used to create a net.Dialer which restricts outbound +// connections to only allowlisted IP ranges. +type RestrictedDialer struct { + allowedHosts []*net.IPNet +} + +// DefaultDialer is a global instance of a RestrictedDialer +var DefaultDialer = &RestrictedDialer{} + +// SetAllowedHosts sets the list of allowed hosts or IP ranges for the default +// dialer. +func SetAllowedHosts(allowed []string) { + DefaultDialer.SetAllowedHosts(allowed) +} + +// AllowedHosts returns the configured hosts that are allowed for the dialer. +func (d *RestrictedDialer) AllowedHosts() []string { + ranges := []string{} + for _, ipRange := range d.allowedHosts { + ranges = append(ranges, ipRange.String()) + } + return ranges +} + +// SetAllowedHosts sets the list of allowed hosts or IP ranges for the dialer. +func (d *RestrictedDialer) SetAllowedHosts(allowed []string) error { + for _, ipRange := range allowed { + // For flexibility, try to parse as an IP first since this will + // undoubtedly cause issues. If it works, then just append the + // appropriate subnet mask, then parse as CIDR + if singleIP := net.ParseIP(ipRange); singleIP != nil { + if singleIP.To4() != nil { + ipRange += "/32" + } else { + ipRange += "/128" + } + } + _, parsed, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("provided ip range is not valid CIDR notation: %v", err) + } + d.allowedHosts = append(d.allowedHosts, parsed) + } + return nil +} + +// Dialer returns a net.Dialer that restricts outbound connections to only the +// addresses allowed by the DefaultDialer. +func Dialer() *net.Dialer { + return DefaultDialer.Dialer() +} + +// Dialer returns a net.Dialer that restricts outbound connections to only the +// allowed addresses over TCP. +// +// By default, since Gophish anticipates connections originating to hosts on +// the local network, we only deny access to the link-local addresses at +// 169.254.0.0/16. +// +// If hosts are provided, then Gophish blocks access to all local addresses +// except the ones provided. +// +// This implementation is based on the blog post by Andrew Ayer at +// https://www.agwa.name/blog/post/preventing_server_side_request_forgery_in_golang +func (d *RestrictedDialer) Dialer() *net.Dialer { + return &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Control: restrictedControl(d.allowedHosts), + } +} + +// defaultDeny represents the list of IP ranges that we want to block unless +// explicitly overriden. +var defaultDeny = []string{ + "169.254.0.0/16", // Link-local (used for VPS instance metadata) +} + +// allInternal represents all internal hosts such that the only connections +// allowed are external ones. +var allInternal = []string{ + "0.0.0.0/8", + "127.0.0.0/8", // IPv4 loopback + "10.0.0.0/8", // RFC1918 + "100.64.0.0/10", // CGNAT + "172.16.0.0/12", // RFC1918 + "169.254.0.0/16", // RFC3927 link-local + "192.88.99.0/24", // IPv6 to IPv4 Relay + "192.168.0.0/16", // RFC1918 + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + "224.0.0.0/4", // Multicast + "240.0.0.0/4", // Reserved + "255.255.255.255/32", // Broadcast + "::/0", // Default route + "::/128", // Unspecified address + "::1/128", // IPv6 loopback + "::ffff:0:0/96", // IPv4 mapped addresses. + "::ffff:0:0:0/96", // IPv4 translated addresses. + "fe80::/10", // IPv6 link-local + "fc00::/7", // IPv6 unique local addr +} + +type dialControl = func(network, address string, c syscall.RawConn) error + +type restrictedDialer struct { + *net.Dialer + allowed []string +} + +func restrictedControl(allowed []*net.IPNet) dialControl { + return func(network string, address string, conn syscall.RawConn) error { + if !(network == "tcp4" || network == "tcp6") { + return fmt.Errorf("%s is not a safe network type", network) + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("%s is not a valid host/port pair: %s", address, err) + } + + ip := net.ParseIP(host) + if ip == nil { + return fmt.Errorf("%s is not a valid IP address", host) + } + + denyList := defaultDeny + if len(allowed) > 0 { + denyList = allInternal + } + + for _, ipRange := range allowed { + if ipRange.Contains(ip) { + return nil + } + } + + for _, ipRange := range denyList { + _, parsed, err := net.ParseCIDR(ipRange) + if err != nil { + return fmt.Errorf("error parsing denied range: %v", err) + } + if parsed.Contains(ip) { + return fmt.Errorf("upstream connection denied to internal host") + } + } + return nil + } +} diff --git a/dialer/dialer_test.go b/dialer/dialer_test.go new file mode 100644 index 00000000..0b70b1a9 --- /dev/null +++ b/dialer/dialer_test.go @@ -0,0 +1,85 @@ +package dialer + +import ( + "fmt" + "net" + "strings" + "syscall" + "testing" +) + +func TestDefaultDeny(t *testing.T) { + control := restrictedControl([]*net.IPNet{}) + host := "169.254.169.254" + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } +} + +func TestDefaultAllow(t *testing.T) { + control := restrictedControl([]*net.IPNet{}) + host := "1.1.1.1" + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } +} + +func TestCustomAllow(t *testing.T) { + host := "127.0.0.1" + _, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host)) + allowed := []*net.IPNet{ipRange} + control := restrictedControl(allowed) + conn := new(syscall.RawConn) + got := control("tcp4", fmt.Sprintf("%s:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } +} + +func TestCustomDeny(t *testing.T) { + host := "127.0.0.1" + _, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host)) + allowed := []*net.IPNet{ipRange} + control := restrictedControl(allowed) + conn := new(syscall.RawConn) + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + got := control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } +} + +func TestSingleIP(t *testing.T) { + orig := DefaultDialer.AllowedHosts() + host := "127.0.0.1" + DefaultDialer.SetAllowedHosts([]string{host}) + control := DefaultDialer.Dialer().Control + conn := new(syscall.RawConn) + expected := fmt.Errorf("upstream connection denied to internal host at %s", host) + got := control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } + + host = "::1" + DefaultDialer.SetAllowedHosts([]string{host}) + control = DefaultDialer.Dialer().Control + conn = new(syscall.RawConn) + expected = fmt.Errorf("upstream connection denied to internal host at %s", host) + got = control("tcp4", "192.168.1.2:80", *conn) + if !strings.Contains(got.Error(), "upstream connection denied") { + t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got) + } + + // Test an allowed connection + got = control("tcp4", fmt.Sprintf("[%s]:80", host), *conn) + if got != nil { + t.Fatalf("error dialing allowed host. got %v", got) + } + DefaultDialer.SetAllowedHosts(orig) +} diff --git a/go.mod b/go.mod index e0649971..6fc6a432 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/emersion/go-imap v1.0.4 github.com/emersion/go-message v0.12.0 github.com/go-sql-driver/mysql v1.5.0 - github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be + github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e github.com/gorilla/context v1.1.1 github.com/gorilla/csrf v1.6.2 github.com/gorilla/handlers v1.4.2 @@ -29,7 +29,5 @@ require ( golang.org/x/crypto v0.0.0-20200128174031-69ecbb4d6d5d golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 gopkg.in/alecthomas/kingpin.v2 v2.2.6 - gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 - gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect ) diff --git a/go.sum b/go.sum index ede80085..c9bae1df 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,8 @@ github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be h1:VTe1cdyqSi/wLowKNz/shz6E0G+9/XzldZbyAmt+0Yw= -github.com/gophish/gomail v0.0.0-20180314010319-cf7e1a5479be/go.mod h1:MpSuP7kw+gRy2z+4gIFZeF3DwhhdQhEXwRmPVQYD9ig= +github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e h1:URNpXdOxXAfuZ8wsr/DY27KTffVenKDjtNVAEwcR2Oo= +github.com/gophish/gomail v0.0.0-20200818021916-1f6d0dfd512e/go.mod h1:JGlHttcLdDp3F4g8bPHqqQnUUDuB3poB4zLXozQ0xCY= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/csrf v1.6.2 h1:QqQ/OWwuFp4jMKgBFAzJVW3FMULdyUW7JoM4pEWuqKg= @@ -110,7 +110,5 @@ gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gG gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= -gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/gophish.go b/gophish.go index 71ae7e8c..67ccab16 100644 --- a/gophish.go +++ b/gophish.go @@ -28,6 +28,7 @@ THE SOFTWARE. import ( "fmt" "io/ioutil" + "net/http" "os" "os/signal" @@ -35,10 +36,12 @@ import ( "github.com/gophish/gophish/config" "github.com/gophish/gophish/controllers" + "github.com/gophish/gophish/dialer" "github.com/gophish/gophish/imap" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/middleware" "github.com/gophish/gophish/models" + "github.com/gophish/gophish/webhook" ) const ( @@ -79,6 +82,13 @@ func main() { } config.Version = string(version) + // Configure our various upstream clients to make sure that we restrict + // outbound connections as needed. + dialer.SetAllowedHosts(conf.AdminConf.AllowedInternalHosts) + webhook.SetTransport(&http.Transport{ + DialContext: dialer.Dialer().DialContext, + }) + err = log.Setup(conf.Logging) if err != nil { log.Fatal(err) diff --git a/imap/imap.go b/imap/imap.go index 7e056176..aa76a4af 100644 --- a/imap/imap.go +++ b/imap/imap.go @@ -11,6 +11,7 @@ import ( "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" "github.com/emersion/go-message/charset" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/models" @@ -184,12 +185,13 @@ func (mbox *Mailbox) GetUnread(markAsRead, delete bool) ([]Email, error) { func (mbox *Mailbox) newClient() (*client.Client, error) { var imapClient *client.Client var err error + restrictedDialer := dialer.Dialer() if mbox.TLS { config := new(tls.Config) config.InsecureSkipVerify = mbox.IgnoreCertErrors - imapClient, err = client.DialTLS(mbox.Host, config) + imapClient, err = client.DialWithDialerTLS(restrictedDialer, mbox.Host, config) } else { - imapClient, err = client.Dial(mbox.Host) + imapClient, err = client.DialWithDialer(restrictedDialer, mbox.Host) } if err != nil { return imapClient, err diff --git a/models/smtp.go b/models/smtp.go index 8ca8485b..cd4d4e23 100644 --- a/models/smtp.go +++ b/models/smtp.go @@ -10,6 +10,7 @@ import ( "time" "github.com/gophish/gomail" + "github.com/gophish/gophish/dialer" log "github.com/gophish/gophish/logger" "github.com/gophish/gophish/mailer" "github.com/jinzhu/gorm" @@ -109,7 +110,8 @@ func (s *SMTP) GetDialer() (mailer.Dialer, error) { log.Error(err) return nil, err } - d := gomail.NewDialer(host, port, s.Username, s.Password) + dialer := dialer.Dialer() + d := gomail.NewWithDialer(dialer, host, port, s.Username, s.Password) d.TLSConfig = &tls.Config{ ServerName: host, InsecureSkipVerify: s.IgnoreCertErrors, diff --git a/models/smtp_test.go b/models/smtp_test.go index 7ffbaadf..b559c282 100644 --- a/models/smtp_test.go +++ b/models/smtp_test.go @@ -81,3 +81,15 @@ func (s *ModelsSuite) TestGetInvalidSMTP(ch *check.C) { _, err := GetSMTP(-1, 1) ch.Assert(err, check.Equals, gorm.ErrRecordNotFound) } + +func (s *ModelsSuite) TestDefaultDeniedDial(ch *check.C) { + host := "169.254.169.254" + port := 25 + smtp := SMTP{ + Host: fmt.Sprintf("%s:%d", host, port), + } + d, err := smtp.GetDialer() + ch.Assert(err, check.Equals, nil) + _, err = d.Dial() + ch.Assert(err, check.ErrorMatches, ".*upstream connection denied.*") +} diff --git a/webhook/webhook.go b/webhook/webhook.go index 0ce281b4..92ee20bf 100644 --- a/webhook/webhook.go +++ b/webhook/webhook.go @@ -51,6 +51,11 @@ var senderInstance = &defaultSender{ }, } +// SetTransport sets the underlying transport for the default webhook client. +func SetTransport(tr *http.Transport) { + senderInstance.client.Transport = tr +} + // EndPoint represents a URL to send the webhook to, as well as a secret used // to sign the event type EndPoint struct {