aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSam Anthony <sam@samanthony.xyz>2025-04-11 18:03:29 -0400
committerSam Anthony <sam@samanthony.xyz>2025-04-11 18:03:29 -0400
commitc2ccb63578460b0bf22fba3d2287365dec22014a (patch)
tree053cbde446eaf918ae56ea023687999b1eb255de
parentf558b8a0552b16a84d2b4aafb4b9cd74122522e8 (diff)
downloadhose-c2ccb63578460b0bf22fba3d2287365dec22014a.zip
retry handshake connection until timeout
-rw-r--r--handshake.go54
1 files changed, 48 insertions, 6 deletions
diff --git a/handshake.go b/handshake.go
index 0bcb0cd..d1ed57d 100644
--- a/handshake.go
+++ b/handshake.go
@@ -10,21 +10,47 @@ import (
"os"
"slices"
"strings"
+ "time"
"git.samanthony.xyz/hose/hosts"
"git.samanthony.xyz/hose/key"
"git.samanthony.xyz/hose/util"
)
+const (
+ timeout = 1 * time.Minute
+ retryInterval = 500 * time.Millisecond
+)
+
// handshake exchanges public keys with a remote host.
// The user is asked to verify the received key
// before it is saved in the known hosts file.
func handshake(rhost string) error {
util.Logf("initiating handshake with %s...", rhost)
- group, _ := errgroup.WithContext(context.Background())
- group.Go(func() error { return handshakeSend(rhost) })
- group.Go(func() error { return handshakeRecv(rhost) })
- return group.Wait()
+
+ errs := make(chan error, 2)
+ defer close(errs)
+
+ group, ctx := errgroup.WithContext(context.Background())
+ group.Go(func() error {
+ if err := handshakeSend(rhost); err != nil {
+ errs <- err
+ }
+ return nil
+ })
+ group.Go(func() error {
+ if err := handshakeRecv(rhost); err != nil {
+ errs <- err
+ }
+ return nil
+ })
+
+ select {
+ case err := <-errs:
+ return err
+ case <-ctx.Done():
+ return nil
+ }
}
// handshakeSend sends the local public key to a remote host.
@@ -37,8 +63,7 @@ func handshakeSend(rhost string) error {
raddr := net.JoinHostPort(rhost, port)
util.Logf("connecting to %s...", raddr)
- var d net.Dialer
- conn, err := d.Dial(network, raddr)
+ conn, err := dialWithTimeout(network, raddr, timeout)
if err != nil {
return err
}
@@ -53,6 +78,23 @@ func handshakeSend(rhost string) error {
return nil
}
+func dialWithTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ for {
+ select {
+ case <-ctx.Done(): // timeout.
+ return nil, fmt.Errorf("dial %s %s: connection refused", network, address)
+ default:
+ }
+ conn, err := net.Dial(network, address)
+ if err == nil {
+ return conn, nil
+ }
+ time.Sleep(retryInterval)
+ }
+}
+
// handshakeRecv receives the public key of a remote host.
// The user is asked to verify the key before it is saved to the known hosts file.
func handshakeRecv(rhost string) error {