gophish/dialer/dialer_test.go

86 lines
2.7 KiB
Go

package dialer
import (
"fmt"
"net"
"strings"
"syscall"
"testing"
)
// TestDefaultDeny checks that, when restrictedControl is given an empty
// allow-list, the returned control function rejects a connection to the
// link-local address 169.254.169.254 with an "upstream connection denied"
// error.
func TestDefaultDeny(t *testing.T) {
	control := restrictedControl([]*net.IPNet{})
	host := "169.254.169.254"
	expected := fmt.Errorf("upstream connection denied to internal host at %s", host)
	// Dereferencing a freshly allocated *syscall.RawConn yields a nil
	// interface value, which is what gets passed to the control function.
	conn := new(syscall.RawConn)
	got := control("tcp4", fmt.Sprintf("%s:80", host), *conn)
	// Guard against a nil error first: calling got.Error() on nil would panic
	// and hide the real failure (the host was unexpectedly allowed).
	if got == nil {
		t.Fatalf("no error dialing denylisted host. expected %v got %v", expected, got)
	}
	if !strings.Contains(got.Error(), "upstream connection denied") {
		t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
	}
}
// TestDefaultAllow checks that, when restrictedControl is given an empty
// allow-list, the returned control function reports no error for a
// connection to 1.1.1.1.
func TestDefaultAllow(t *testing.T) {
	const publicHost = "1.1.1.1"

	// A nil RawConn: the same value the control function would receive from
	// dereferencing a freshly allocated *syscall.RawConn.
	var rawConn syscall.RawConn

	check := restrictedControl([]*net.IPNet{})
	addr := fmt.Sprintf("%s:80", publicHost)

	if err := check("tcp4", addr, rawConn); err != nil {
		t.Fatalf("error dialing allowed host. got %v", err)
	}
}
// TestCustomAllow checks that a host covered by a caller-supplied allow-list
// is permitted by the control function returned by restrictedControl. The
// allow-list here is the single /32 range containing 127.0.0.1.
func TestCustomAllow(t *testing.T) {
	host := "127.0.0.1"
	_, ipRange, err := net.ParseCIDR(fmt.Sprintf("%s/32", host))
	if err != nil {
		// Fail fast on a broken fixture rather than handing a nil *net.IPNet
		// to restrictedControl.
		t.Fatalf("error parsing allowed range for %s: %v", host, err)
	}
	allowed := []*net.IPNet{ipRange}
	control := restrictedControl(allowed)
	// Dereferencing a freshly allocated *syscall.RawConn yields a nil
	// interface value, which is what gets passed to the control function.
	conn := new(syscall.RawConn)
	got := control("tcp4", fmt.Sprintf("%s:80", host), *conn)
	if got != nil {
		t.Fatalf("error dialing allowed host. got %v", got)
	}
}
// TestCustomDeny checks that, when a custom allow-list containing only
// 127.0.0.1/32 is supplied, the control function returned by
// restrictedControl rejects a connection to 192.168.1.2 (a host outside that
// list) with an "upstream connection denied" error.
func TestCustomDeny(t *testing.T) {
	allowedHost := "127.0.0.1"
	_, ipRange, err := net.ParseCIDR(fmt.Sprintf("%s/32", allowedHost))
	if err != nil {
		// Fail fast on a broken fixture rather than handing a nil *net.IPNet
		// to restrictedControl.
		t.Fatalf("error parsing allowed range for %s: %v", allowedHost, err)
	}
	allowed := []*net.IPNet{ipRange}
	control := restrictedControl(allowed)
	// Dereferencing a freshly allocated *syscall.RawConn yields a nil
	// interface value, which is what gets passed to the control function.
	conn := new(syscall.RawConn)

	// The expected error is built from the host being dialed (as in
	// TestDefaultDeny), not from the allowed host.
	deniedHost := "192.168.1.2"
	expected := fmt.Errorf("upstream connection denied to internal host at %s", deniedHost)
	got := control("tcp4", fmt.Sprintf("%s:80", deniedHost), *conn)
	// Guard against a nil error first: calling got.Error() on nil would panic
	// and hide the real failure (the host was unexpectedly allowed).
	if got == nil {
		t.Fatalf("no error dialing denylisted host. expected %v got %v", expected, got)
	}
	if !strings.Contains(got.Error(), "upstream connection denied") {
		t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
	}
}
// TestSingleIP checks that single IP addresses (no CIDR mask) passed to
// DefaultDialer.SetAllowedHosts are honoured by the Control function of the
// dialer it produces. It covers an IPv4 entry (127.0.0.1) and an IPv6 entry
// (::1): with either one set, a connection to 192.168.1.2 must be denied, and
// with ::1 set a connection to [::1]:80 must be allowed.
func TestSingleIP(t *testing.T) {
	// Restore the package-level dialer configuration even if an assertion
	// below aborts the test. t.Fatalf stops the test via runtime.Goexit, which
	// still runs deferred calls, whereas a trailing statement would be skipped
	// and leave the modified allow-list in place for later tests.
	orig := DefaultDialer.AllowedHosts()
	defer DefaultDialer.SetAllowedHosts(orig)

	// The expected error is built from the host being dialed (as in
	// TestDefaultDeny), not from the allowed host.
	deniedHost := "192.168.1.2"
	deniedAddr := fmt.Sprintf("%s:80", deniedHost)
	expected := fmt.Errorf("upstream connection denied to internal host at %s", deniedHost)
	// Dereferencing a freshly allocated *syscall.RawConn yields a nil
	// interface value, which is what gets passed to the control function.
	conn := new(syscall.RawConn)

	// IPv4 single-address allow-list.
	host := "127.0.0.1"
	DefaultDialer.SetAllowedHosts([]string{host})
	control := DefaultDialer.Dialer().Control
	got := control("tcp4", deniedAddr, *conn)
	// Guard against a nil error first: calling got.Error() on nil would panic
	// and hide the real failure (the host was unexpectedly allowed).
	if got == nil {
		t.Fatalf("no error dialing denylisted host. expected %v got %v", expected, got)
	}
	if !strings.Contains(got.Error(), "upstream connection denied") {
		t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
	}

	// IPv6 single-address allow-list.
	host = "::1"
	DefaultDialer.SetAllowedHosts([]string{host})
	control = DefaultDialer.Dialer().Control
	got = control("tcp4", deniedAddr, *conn)
	if got == nil {
		t.Fatalf("no error dialing denylisted host. expected %v got %v", expected, got)
	}
	if !strings.Contains(got.Error(), "upstream connection denied") {
		t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
	}

	// Test an allowed connection
	got = control("tcp4", fmt.Sprintf("[%s]:80", host), *conn)
	if got != nil {
		t.Fatalf("error dialing allowed host. got %v", got)
	}
}