mirror of https://github.com/gophish/gophish
86 lines
2.7 KiB
Go
86 lines
2.7 KiB
Go
|
package dialer
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strings"
|
||
|
"syscall"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func TestDefaultDeny(t *testing.T) {
|
||
|
control := restrictedControl([]*net.IPNet{})
|
||
|
host := "169.254.169.254"
|
||
|
expected := fmt.Errorf("upstream connection denied to internal host at %s", host)
|
||
|
conn := new(syscall.RawConn)
|
||
|
got := control("tcp4", fmt.Sprintf("%s:80", host), *conn)
|
||
|
if !strings.Contains(got.Error(), "upstream connection denied") {
|
||
|
t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestDefaultAllow(t *testing.T) {
|
||
|
control := restrictedControl([]*net.IPNet{})
|
||
|
host := "1.1.1.1"
|
||
|
conn := new(syscall.RawConn)
|
||
|
got := control("tcp4", fmt.Sprintf("%s:80", host), *conn)
|
||
|
if got != nil {
|
||
|
t.Fatalf("error dialing allowed host. got %v", got)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCustomAllow(t *testing.T) {
|
||
|
host := "127.0.0.1"
|
||
|
_, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host))
|
||
|
allowed := []*net.IPNet{ipRange}
|
||
|
control := restrictedControl(allowed)
|
||
|
conn := new(syscall.RawConn)
|
||
|
got := control("tcp4", fmt.Sprintf("%s:80", host), *conn)
|
||
|
if got != nil {
|
||
|
t.Fatalf("error dialing allowed host. got %v", got)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCustomDeny(t *testing.T) {
|
||
|
host := "127.0.0.1"
|
||
|
_, ipRange, _ := net.ParseCIDR(fmt.Sprintf("%s/32", host))
|
||
|
allowed := []*net.IPNet{ipRange}
|
||
|
control := restrictedControl(allowed)
|
||
|
conn := new(syscall.RawConn)
|
||
|
expected := fmt.Errorf("upstream connection denied to internal host at %s", host)
|
||
|
got := control("tcp4", "192.168.1.2:80", *conn)
|
||
|
if !strings.Contains(got.Error(), "upstream connection denied") {
|
||
|
t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSingleIP(t *testing.T) {
|
||
|
orig := DefaultDialer.AllowedHosts()
|
||
|
host := "127.0.0.1"
|
||
|
DefaultDialer.SetAllowedHosts([]string{host})
|
||
|
control := DefaultDialer.Dialer().Control
|
||
|
conn := new(syscall.RawConn)
|
||
|
expected := fmt.Errorf("upstream connection denied to internal host at %s", host)
|
||
|
got := control("tcp4", "192.168.1.2:80", *conn)
|
||
|
if !strings.Contains(got.Error(), "upstream connection denied") {
|
||
|
t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
|
||
|
}
|
||
|
|
||
|
host = "::1"
|
||
|
DefaultDialer.SetAllowedHosts([]string{host})
|
||
|
control = DefaultDialer.Dialer().Control
|
||
|
conn = new(syscall.RawConn)
|
||
|
expected = fmt.Errorf("upstream connection denied to internal host at %s", host)
|
||
|
got = control("tcp4", "192.168.1.2:80", *conn)
|
||
|
if !strings.Contains(got.Error(), "upstream connection denied") {
|
||
|
t.Fatalf("unexpected error dialing denylisted host. expected %v got %v", expected, got)
|
||
|
}
|
||
|
|
||
|
// Test an allowed connection
|
||
|
got = control("tcp4", fmt.Sprintf("[%s]:80", host), *conn)
|
||
|
if got != nil {
|
||
|
t.Fatalf("error dialing allowed host. got %v", got)
|
||
|
}
|
||
|
DefaultDialer.SetAllowedHosts(orig)
|
||
|
}
|