diff --git a/main.go b/main.go index 2985129..2b1370e 100644 --- a/main.go +++ b/main.go @@ -26,6 +26,7 @@ import ( "os" "strconv" "strings" + "sync" "time" properties "github.com/arduino/go-properties-orderedmap" @@ -49,8 +50,11 @@ const mdnsServiceName = "_arduino._tcp" // since the last time they've been found by an mDNS query. const portsTTL = time.Second * 60 -// This is interval at which mDNS queries are made. -const discoveryInterval = time.Second * 15 +// Interval at which we check available network interfaces and call mdns.Query() +const queryInterval = time.Second * 30 + +// mdns.Query() will either exit early or timeout after this amount of time +const queryTimeout = time.Second * 15 // IP address used to check if we're connected to a local network var ipv4Addr = &net.UDPAddr{ @@ -64,6 +68,12 @@ var ipv6Addr = &net.UDPAddr{ Port: 5353, } +// QueryParam{} has to select which IP version(s) to use +type connectivity struct { + IPv4 bool + IPv6 bool +} + // MDNSDiscovery is the implementation of the network pluggable-discovery type MDNSDiscovery struct { cancelFunc func() @@ -140,55 +150,131 @@ func (d *MDNSDiscovery) StartSync(eventCB discovery.EventCallback, errorCB disco ctx, cancel := context.WithCancel(context.Background()) go func() { defer close(queriesChan) + queryLoop(ctx, queriesChan) + }() + go func() { + for entry := range queriesChan { + if d.entriesChan != nil { + d.entriesChan <- entry + } + } + }() + d.cancelFunc = cancel + return nil +} - disableIPv6 := false - // Check if the current network supports IPv6 - mconn6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr) +func queryLoop(ctx context.Context, queriesChan chan<- *mdns.ServiceEntry) { + for { + var interfaces []net.Interface + var conn connectivity + var wg sync.WaitGroup + + interfaces, err := availableInterfaces() if err != nil { - disableIPv6 = true - } else { - mconn6.Close() + goto NEXT } - // We must check if we're connected to a local network, if we don't - // the subsequent mDNS query would fail and return an error. - mconn4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr) - if err != nil { + conn = checkConnectivity() + if !conn.available() { + goto NEXT + } + + wg.Add(len(interfaces)) + + for n := range interfaces { + params := makeQueryParams(&interfaces[n], conn, queriesChan) + go func() { + defer wg.Done() + mdns.Query(params) + }() + } + + wg.Wait() + + NEXT: + select { + case <-time.After(queryInterval): + case <-ctx.Done(): return } - // If we managed to open a connection close it, mdns.Query opens - // another one on the same IP address we use and it would fail - // if we leave this open. + } +} + +func (conn *connectivity) available() bool { + return conn.IPv4 || conn.IPv6 +} + +func checkConnectivity() connectivity { + // We must check if we're connected to a local network, if we don't + // the subsequent mDNS query would fail and return an error. + // If we managed to open a connection close it, mdns.Query opens + // another one on the same IP address we use and it would fail + // if we leave this open. + out := connectivity{ + IPv4: true, + IPv6: true, + } + + // Check if the current network supports IPv6 + mconn6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr) + if err != nil { + out.IPv6 = false + } else { + mconn6.Close() + } + + // And the same for IPv4 + mconn4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr) + if err != nil { + out.IPv4 = false + } else { mconn4.Close() + } + + return out +} + +func availableInterfaces() ([]net.Interface, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } - params := &mdns.QueryParam{ - Service: mdnsServiceName, - Domain: "local", - Timeout: discoveryInterval, - Entries: queriesChan, - WantUnicastResponse: false, - DisableIPv6: disableIPv6, + var out []net.Interface + for _, netif := range interfaces { + if netif.Flags&net.FlagUp == 0 { + continue } - for { - if err := mdns.Query(params); err != nil { - errorCB("mdns lookup error: " + err.Error()) - } - select { - default: - case <-ctx.Done(): - return - } + + if netif.Flags&net.FlagMulticast == 0 { + continue } - }() - go func() { - for entry := range queriesChan { - if d.entriesChan != nil { - d.entriesChan <- entry - } + + if netif.HardwareAddr == nil { + continue } - }() - d.cancelFunc = cancel - return nil + + out = append(out, netif) + } + + if len(out) == 0 { + return nil, fmt.Errorf("no valid network interfaces") + } + + return out, nil +} + +func makeQueryParams(netif *net.Interface, conn connectivity, queriesChan chan<- *mdns.ServiceEntry) (params *mdns.QueryParam) { + return &mdns.QueryParam{ + Service: mdnsServiceName, + Domain: "local", + Timeout: queryTimeout, + Interface: netif, + Entries: queriesChan, + WantUnicastResponse: false, + DisableIPv4: !conn.IPv4, + DisableIPv6: !conn.IPv6, + } } func toDiscoveryPort(entry *mdns.ServiceEntry) *discovery.Port {