@ -18,6 +18,7 @@ import (
"strings"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@ -38,7 +39,7 @@ type netTun struct {
events chan tun . Event
incomingPacket chan buffer . VectorisedView
mtu int
dnsServers [ ] net . IP
dnsServers [ ] netip . Addr
hasV4 , hasV6 bool
}
type endpoint netTun
@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func ( e * endpoint ) AddHeader ( tcpip . LinkAddress , tcpip . LinkAddress , tcpip . NetworkProtocolNumber , * stack . PacketBuffer ) {
}
func CreateNetTUN ( localAddresses , dnsServers [ ] net . IP , mtu int ) ( tun . Device , * Net , error ) {
func CreateNetTUN ( localAddresses , dnsServers [ ] netip . Addr , mtu int ) ( tun . Device , * Net , error ) {
opts := stack . Options {
NetworkProtocols : [ ] stack . NetworkProtocolFactory { ipv4 . NewProtocol , ipv6 . NewProtocol } ,
TransportProtocols : [ ] stack . TransportProtocolFactory { tcp . NewProtocol , udp . NewProtocol } ,
@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil , nil , fmt . Errorf ( "CreateNIC: %v" , tcpipErr )
}
for _ , ip := range localAddresses {
if ip4 := ip . To4 ( ) ; ip4 != nil {
protoAddr := tcpip . ProtocolAddress {
Protocol : ipv4 . ProtocolNumber ,
AddressWithPrefix : tcpip . Address ( ip4 ) . WithPrefix ( ) ,
}
tcpipErr := dev . stack . AddProtocolAddress ( 1 , protoAddr , stack . AddressProperties { } )
if tcpipErr != nil {
return nil , nil , fmt . Errorf ( "AddProtocolAddress(%v): %v" , ip4 , tcpipErr )
}
var protoNumber tcpip . NetworkProtocolNumber
if ip . Is4 ( ) {
protoNumber = ipv4 . ProtocolNumber
} else if ip . Is6 ( ) {
protoNumber = ipv6 . ProtocolNumber
}
protoAddr := tcpip . ProtocolAddress {
Protocol : protoNumber ,
AddressWithPrefix : tcpip . Address ( ip . AsSlice ( ) ) . WithPrefix ( ) ,
}
tcpipErr := dev . stack . AddProtocolAddress ( 1 , protoAddr , stack . AddressProperties { } )
if tcpipErr != nil {
return nil , nil , fmt . Errorf ( "AddProtocolAddress(%v): %v" , ip , tcpipErr )
}
if ip . Is4 ( ) {
dev . hasV4 = true
} else {
protoAddr := tcpip . ProtocolAddress {
Protocol : ipv6 . ProtocolNumber ,
AddressWithPrefix : tcpip . Address ( ip ) . WithPrefix ( ) ,
}
tcpipErr := dev . stack . AddProtocolAddress ( 1 , protoAddr , stack . AddressProperties { } )
if tcpipErr != nil {
return nil , nil , fmt . Errorf ( "AddProtocolAddress(%v): %v" , ip , tcpipErr )
}
} else if ip . Is6 ( ) {
dev . hasV6 = true
}
}
@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun . mtu , nil
}
func convertToFullAddr ( ip net . IP , port int ) ( tcpip . FullAddress , tcpip . NetworkProtocolNumber ) {
if ip4 := ip . To4 ( ) ; ip4 != nil {
return tcpip . FullAddress {
NIC : 1 ,
Addr : tcpip . Address ( ip4 ) ,
Port : uint16 ( port ) ,
} , ipv4 . ProtocolNumber
func convertToFullAddr ( endpoint netip . AddrPort ) ( tcpip . FullAddress , tcpip . NetworkProtocolNumber ) {
var protoNumber tcpip . NetworkProtocolNumber
if endpoint . Addr ( ) . Is4 ( ) {
protoNumber = ipv4 . ProtocolNumber
} else {
return tcpip . FullAddress {
NIC : 1 ,
Addr : tcpip . Address ( ip ) ,
Port : uint16 ( port ) ,
} , ipv6 . ProtocolNumber
protoNumber = ipv6 . ProtocolNumber
}
return tcpip . FullAddress {
NIC : 1 ,
Addr : tcpip . Address ( endpoint . Addr ( ) . AsSlice ( ) ) ,
Port : endpoint . Port ( ) ,
} , protoNumber
}
func ( net * Net ) DialContextTCPAddrPort ( ctx context . Context , addr netip . AddrPort ) ( * gonet . TCPConn , error ) {
fa , pn := convertToFullAddr ( addr )
return gonet . DialContextTCP ( ctx , net . stack , fa , pn )
}
func ( net * Net ) DialContextTCP ( ctx context . Context , addr * net . TCPAddr ) ( * gonet . TCPConn , error ) {
if addr == nil {
panic ( "todo: deal with auto addr semantics for nil addr" )
return net . DialContextTCPAddrPort ( ctx , netip . AddrPort { } )
}
fa , pn := convertToFullAddr ( addr . IP , addr . Port )
return gonet . DialContextTCP ( ctx , net . stack , fa , pn )
return net . DialContextTCPAddrPort ( ctx , netip . AddrPortFrom ( netip . AddrFromSlice ( addr . IP ) , uint16 ( addr . Port ) ) )
}
func ( net * Net ) DialTCPAddrPort ( addr netip . AddrPort ) ( * gonet . TCPConn , error ) {
fa , pn := convertToFullAddr ( addr )
return gonet . DialTCP ( net . stack , fa , pn )
}
func ( net * Net ) DialTCP ( addr * net . TCPAddr ) ( * gonet . TCPConn , error ) {
if addr == nil {
panic ( "todo: deal with auto addr semantics for nil addr" )
return net . DialTCPAddrPort ( netip . AddrPort { } )
}
fa , pn := convertToFullAddr ( addr . IP , addr . Port )
return gonet . DialTCP ( net . stack , fa , pn )
return net . DialTCPAddrPort ( netip . AddrPortFrom ( netip . AddrFromSlice ( addr . IP ) , uint16 ( addr . Port ) ) )
}
func ( net * Net ) ListenTCPAddrPort ( addr netip . AddrPort ) ( * gonet . TCPListener , error ) {
fa , pn := convertToFullAddr ( addr )
return gonet . ListenTCP ( net . stack , fa , pn )
}
func ( net * Net ) ListenTCP ( addr * net . TCPAddr ) ( * gonet . TCPListener , error ) {
if addr == nil {
panic ( "todo: deal with auto addr semantics for nil addr" )
return net . ListenTCPAddrPort ( netip . AddrPort { } )
}
fa , pn := convertToFullAddr ( addr . IP , addr . Port )
return gonet . ListenTCP ( net . stack , fa , pn )
return net . ListenTCPAddrPort ( netip . AddrPortFrom ( netip . AddrFromSlice ( addr . IP ) , uint16 ( addr . Port ) ) )
}
func ( net * Net ) DialUDP ( laddr , raddr * net . UDP Addr) ( * gonet . UDPConn , error ) {
func ( net * Net ) DialUDPAddrPort ( laddr , raddr netip . AddrPort ) ( * gonet . UDPConn , error ) {
var lfa , rfa * tcpip . FullAddress
var pn tcpip . NetworkProtocolNumber
if laddr != nil {
if laddr . IsValid ( ) || laddr . Port ( ) > 0 {
var addr tcpip . FullAddress
addr , pn = convertToFullAddr ( laddr . IP , laddr . Port )
addr , pn = convertToFullAddr ( laddr )
lfa = & addr
}
if raddr != nil {
if raddr . IsValid ( ) || raddr . Port ( ) > 0 {
var addr tcpip . FullAddress
addr , pn = convertToFullAddr ( raddr . IP , raddr . Port )
addr , pn = convertToFullAddr ( raddr )
rfa = & addr
}
return gonet . DialUDP ( net . stack , lfa , rfa , pn )
}
func ( net * Net ) DialUDP ( laddr , raddr * net . UDPAddr ) ( * gonet . UDPConn , error ) {
var la , ra netip . AddrPort
if laddr != nil {
la = netip . AddrPortFrom ( netip . AddrFromSlice ( laddr . IP ) , uint16 ( laddr . Port ) )
}
if raddr != nil {
ra = netip . AddrPortFrom ( netip . AddrFromSlice ( raddr . IP ) , uint16 ( raddr . Port ) )
}
return net . DialUDPAddrPort ( la , ra )
}
var (
errNoSuchHost = errors . New ( "no such host" )
errLameReferral = errors . New ( "lame referral" )
@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p , h , nil
}
func ( tnet * Net ) exchange ( ctx context . Context , server net . IP , q dnsmessage . Question , timeout time . Duration ) ( dnsmessage . Parser , dnsmessage . Header , error ) {
func ( tnet * Net ) exchange ( ctx context . Context , server netip . Addr , q dnsmessage . Question , timeout time . Duration ) ( dnsmessage . Parser , dnsmessage . Header , error ) {
q . Class = dnsmessage . ClassINET
id , udpReq , tcpReq , err := newRequest ( q )
if err != nil {
@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net . Conn
var err error
if useUDP {
c , err = tnet . DialUDP ( nil , & net . UDPAddr { IP : server , Port : 53 } )
c , err = tnet . DialUDPAddrPort ( netip . AddrPort { } , netip . AddrPortFrom ( server , 53 ) )
} else {
c , err = tnet . DialContextTCP ( ctx , & net . TCPAddr { IP : server , Port : 53 } )
c , err = tnet . DialContextTCPAddrPort ( ctx , netip . AddrPortFrom ( server , 53 ) )
}
if err != nil {
@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
if ip := net . ParseIP ( host [ : zlen ] ) ; ip ! = nil {
return [ ] string { host [ : zlen ] } , nil
if ip , err := netip . ParseAddr ( host [ : zlen ] ) ; err = = nil {
return [ ] string { ip . String ( ) } , nil
}
if ! isDomainName ( host ) {
@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
var addrsV4 , addrsV6 [ ] net . IP
var addrsV4 , addrsV6 [ ] netip . Addr
lanes := 0
if tnet . hasV4 {
lanes ++
@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV4 = append ( addrsV4 , net . IP ( a . A [ : ] ) )
addrsV4 = append ( addrsV4 , netip . AddrFrom4 ( a . A ) )
case dnsmessage . TypeAAAA :
aaaa , err := result . p . AAAAResource ( )
@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV6 = append ( addrsV6 , net . IP ( aaaa . AAAA [ : ] ) )
addrsV6 = append ( addrsV6 , netip . AddrFrom16 ( aaaa . AAAA ) )
default :
if err := result . p . SkipAnswer ( ) ; err != nil {
@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
var addrs [ ] net . IP
var addrs [ ] netip . Addr
if tnet . hasV6 {
addrs = append ( addrsV6 , addrsV4 ... )
} else {
@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil , & net . OpError { Op : "dial" , Err : err }
}
var addrs [ ] net . IP
var addrs [ ] netip . AddrPort
for _ , addr := range allAddr {
if strings . IndexByte ( addr , ':' ) != - 1 && acceptV6 {
addrs = append ( addrs , net . ParseIP ( addr ) )
} else if strings . IndexByte ( addr , '.' ) != - 1 && acceptV4 {
addrs = append ( addrs , net . ParseIP ( addr ) )
ip , err := netip . ParseAddr ( addr )
if err == nil && ( ( ip . Is4 ( ) && acceptV4 ) || ( ip . Is6 ( ) && acceptV6 ) ) {
addrs = append ( addrs , netip . AddrPortFrom ( ip , uint16 ( port ) ) )
}
}
if len ( addrs ) == 0 && len ( allAddr ) != 0 {
@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net . Conn
if useUDP {
c , err = tnet . DialUDP ( nil , & net . UDPAddr { IP : addr , Port : port } )
c , err = tnet . DialUDPAddrPort ( netip . AddrPort { } , addr )
} else {
c , err = tnet . DialContextTCP ( dialCtx , & net . TCPA ddr{ IP : addr , Port : port } )
c , err = tnet . DialContextTCPAddrPort ( dialCtx , a ddr)
}
if err == nil {
return c , nil