aboutsummaryrefslogtreecommitdiffstats
path: root/handshake/receive.go
blob: 986b96730f551f98ee3620b4deda4e5a84d89d7f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package handshake

import (
	"bufio"
	"errors"
	"io"
	"net"
	"net/netip"
	"os"
	"slices"
	"strings"

	"git.samanthony.xyz/hose/hosts"
	"git.samanthony.xyz/hose/key"
	"git.samanthony.xyz/hose/util"
)

// keyType names a kind of public key. Its value is shown to the user when
// they are asked to verify a key received from a remote host.
type keyType string

// Kinds of public key received from a remote host.
//
// Each constant spells out its keyType type: inside a const block the type of
// the previous line is only inherited when the value is omitted too, so a
// line with its own value but no type would be an untyped string constant.
const (
	boxPublicKey keyType = "Public encryption key"
	sigPublicKey keyType = "Public signature verification key"
)

// errVerifyKey is returned when the user rejects a key received from a
// remote host.
var errVerifyKey = errors.New("host key verification failed")

// receive accepts one incoming connection and reads the connecting host's
// public encryption key and public signature verification key from it. The
// user is asked to confirm both keys; only if both are confirmed is the host
// added to the known hosts file.
//
// NOTE(review): rhost is not used. Whichever host connects first is the one
// whose keys are received. Confirm whether the peer address should be checked
// against rhost.
func receive(rhost string) error {
	conn, err := acceptConnection()
	if err != nil {
		return err
	}
	defer conn.Close()
	util.Logf("accepted connection from %s", conn.RemoteAddr())

	boxKey, sigKey, err := receiveKeys(conn)
	if err != nil {
		return err
	}

	// The host is identified by its IP address alone, so drop the port from
	// the peer address before asking the user to verify the keys.
	hostPart, _, err := net.SplitHostPort(conn.RemoteAddr().String())
	if err != nil {
		return err
	}
	addr, err := netip.ParseAddr(hostPart)
	if err != nil {
		return err
	}

	if err = verifyKeys(addr, boxKey, sigKey); err != nil {
		return err
	}

	// NOTE(review): unkeyed composite literal of an imported struct type;
	// go vet's composites check flags this. Switch to keyed fields once the
	// hosts.Host field names are confirmed.
	return hosts.Add(hosts.Host{addr, boxKey, sigKey})
}

// acceptConnection listens on the package-level network and port, waits for
// a single incoming connection, and returns it. The listener is closed as
// soon as Accept returns, so no further connections are accepted.
func acceptConnection() (net.Conn, error) {
	addr := net.JoinHostPort("", port)
	listener, err := net.Listen(network, addr)
	if err != nil {
		return nil, err
	}
	defer listener.Close()

	util.Logf("listening on %s", addr)
	conn, err := listener.Accept()
	return conn, err
}

// receiveKeys reads the remote host's public keys from conn. It reads the
// public encryption (box) key first and then the public signature
// verification key, filling each key buffer completely.
//
// If either key cannot be read in full, it returns zero-valued keys and the
// error from io.ReadFull.
func receiveKeys(conn net.Conn) (key.BoxPublicKey, key.SigPublicKey, error) {
	// Receive public box (encryption) key from remote host.
	var rBoxPubKey key.BoxPublicKey
	if _, err := io.ReadFull(conn, rBoxPubKey[:]); err != nil {
		return key.BoxPublicKey{}, key.SigPublicKey{}, err
	}
	util.Logf("received public encryption key from %s", conn.RemoteAddr())

	// Receive public signature verification key from remote host.
	var rSigPubKey key.SigPublicKey
	if _, err := io.ReadFull(conn, rSigPubKey[:]); err != nil {
		return key.BoxPublicKey{}, key.SigPublicKey{}, err
	}
	util.Logf("received public signature verification key from %s", conn.RemoteAddr())

	return rBoxPubKey, rSigPubKey, nil
}

// verifyKeys asks the user to confirm two keys received from host, one after
// the other: first the public encryption key, then the public signature
// verification key. The second prompt is shown only if the first key was
// accepted. It returns nil only if both keys are accepted; otherwise it
// returns the first error from verifyKey.
func verifyKeys(host netip.Addr, rBoxPubKey key.BoxPublicKey, rSigPubKey key.SigPublicKey) error {
	prompts := []struct {
		material []byte
		kind     keyType
	}{
		{rBoxPubKey[:], boxPublicKey},
		{rSigPubKey[:], sigPublicKey},
	}
	for _, p := range prompts {
		if err := verifyKey(host, p.material, p.kind); err != nil {
			return err
		}
	}
	return nil
}

// verifyKey shows a key received from host (as hex, labelled with its kind kt)
// and asks the user whether it is correct.
//
// It returns:
//   - nil if the user answers "yes";
//   - errVerifyKey if the user answers "no" or enters nothing (the default);
//   - the error from scan if the answer could not be read.
func verifyKey(host netip.Addr, pubKey []byte, kt keyType) error {
	// Ask the user to verify the key. Every keyType label already ends in
	// "key", so the format must not append another "key".
	util.Logf("%s of host %q: %x\nIs this the correct key (yes/[no])?",
		kt, host, pubKey)
	response, err := scan([]string{"yes", "no", ""})
	if err != nil {
		return err
	}
	if response == "yes" {
		return nil
	}
	// "no" and the empty default answer both reject the key.
	return errVerifyKey
}

// stdinScanner is shared by every call to scan. A bufio.Scanner reads ahead,
// so when stdin is a pipe a fresh scanner per call could consume the answers
// to later prompts and then throw them away. A single scanner keeps them.
var stdinScanner = bufio.NewScanner(os.Stdin)

// scan reads lines from stdin until the user enters one of the valid
// responses, and returns that response. Leading and trailing whitespace is
// ignored. After each invalid line the user is reminded of the valid choices.
//
// If stdin ends before a valid response is read, scan returns "" when "" is
// one of the valid responses (treating end of input like an empty answer).
// Otherwise it returns io.ErrUnexpectedEOF, because no valid answer can ever
// arrive. Read errors from stdin are returned as-is.
func scan(responses []string) (string, error) {
	for {
		if !stdinScanner.Scan() {
			if err := stdinScanner.Err(); err != nil {
				return "", err
			}
			// End of input. Retrying would spin forever, since Scan keeps
			// returning false.
			if slices.Contains(responses, "") {
				return "", nil
			}
			return "", io.ErrUnexpectedEOF
		}
		response := strings.TrimSpace(stdinScanner.Text())
		if slices.Contains(responses, response) {
			return response, nil
		}
		util.Logf("Please enter one of %q", responses)
	}
}