diff options
| -rw-r--r-- | handshake.go | 54 |
1 files changed, 48 insertions, 6 deletions
diff --git a/handshake.go b/handshake.go index 0bcb0cd..d1ed57d 100644 --- a/handshake.go +++ b/handshake.go @@ -10,21 +10,47 @@ import ( "os" "slices" "strings" + "time" "git.samanthony.xyz/hose/hosts" "git.samanthony.xyz/hose/key" "git.samanthony.xyz/hose/util" ) +const ( + timeout = 1 * time.Minute + retryInterval = 500 * time.Millisecond +) + // handshake exchanges public keys with a remote host. // The user is asked to verify the received key // before it is saved in the known hosts file. func handshake(rhost string) error { util.Logf("initiating handshake with %s...", rhost) - group, _ := errgroup.WithContext(context.Background()) - group.Go(func() error { return handshakeSend(rhost) }) - group.Go(func() error { return handshakeRecv(rhost) }) - return group.Wait() + + errs := make(chan error, 2) + defer close(errs) + + group, ctx := errgroup.WithContext(context.Background()) + group.Go(func() error { + if err := handshakeSend(rhost); err != nil { + errs <- err + } + return nil + }) + group.Go(func() error { + if err := handshakeRecv(rhost); err != nil { + errs <- err + } + return nil + }) + + select { + case err := <-errs: + return err + case <-ctx.Done(): + return nil + } } // handshakeSend sends the local public key to a remote host. @@ -37,8 +63,7 @@ func handshakeSend(rhost string) error { raddr := net.JoinHostPort(rhost, port) util.Logf("connecting to %s...", raddr) - var d net.Dialer - conn, err := d.Dial(network, raddr) + conn, err := dialWithTimeout(network, raddr, timeout) if err != nil { return err } @@ -53,6 +78,23 @@ func handshakeSend(rhost string) error { return nil } +func dialWithTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for { + select { + case <-ctx.Done(): // timeout. + return nil, fmt.Errorf("dial %s %s: connection refused", network, address) + default: + } + conn, err := net.Dial(network, address) + if err == nil { + return conn, nil + } + time.Sleep(retryInterval) + } +} + // handshakeRecv receives the public key of a remote host. // The user is asked to verify the key before it is saved to the known hosts file. func handshakeRecv(rhost string) error { |