aboutsummaryrefslogtreecommitdiffstats
path: root/hosts/hosts.go
blob: 0412e4a17912f57d4116a7a630fa2b8ce5c7f5dc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package hosts

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"github.com/adrg/xdg"
	"net/netip"
	"os"
	"path/filepath"
	"slices"

	"git.samanthony.xyz/hose/key"
	"git.samanthony.xyz/hose/util"
)

// knownHostsFile is the path of the known hosts file: the file "known_hosts"
// inside the "hose" directory under the XDG data home.
var knownHostsFile = filepath.Join(xdg.DataHome, "hose", "known_hosts")

// Host is one entry of the known hosts file: a network address together with
// the two public keys recorded for it. All three parts are embedded, so their
// fields and methods are promoted onto Host.
type Host struct {
	netip.Addr       // address of the host.
	key.BoxPublicKey // public encryption key.
	key.SigPublicKey // public signature verification key.
}

// Add adds or replaces an entry in the known hosts file.
//
// The current set of hosts is read with Load, updated, and written back in
// full with Store. An existing entry is replaced only when cmpHost reports it
// equal to host (same address and same keys); otherwise host is inserted at
// its sorted position.
//
// It returns any error reported by Load or Store.
func Add(host Host) error {
	hosts, err := Load()
	if err != nil {
		return err
	}

	// Load returns the list sorted by cmpHost, so a binary search yields
	// either the matching entry or the index where host belongs.
	i, ok := slices.BinarySearchFunc(hosts, host, cmpHost)
	if ok {
		// The format verb needs the host as its operand; without it the
		// log line would read "%!q(MISSING)".
		util.Logf("replacing host %q in known hosts file", host)
		hosts[i] = host
	} else {
		hosts = slices.Insert(hosts, i, host)
	}

	return Store(hosts)
}

// Load loads the set of known hosts from disc.
// The returned list is sorted according to cmpHost.
//
// A missing known hosts file is not an error: it yields an empty list.
// Any other failure to open or read the file, a line that parseHost rejects,
// or a duplicate entry is reported as an error; in that case the returned
// slice holds only the entries read before the failure.
func Load() ([]Host, error) {
	hosts := make([]Host, 0)

	f, err := os.Open(knownHostsFile)
	if errors.Is(err, os.ErrNotExist) {
		return hosts, nil // no known hosts yet.
	}
	if err != nil {
		return hosts, err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for line := 1; scanner.Scan(); line++ {
		host, err := parseHost(scanner.Bytes())
		if err != nil {
			// Wrap with %w so callers can still inspect the underlying
			// parse error with errors.Is / errors.As.
			return hosts, fmt.Errorf("error parsing known hosts file: %s:%d: %w", knownHostsFile, line, err)
		}
		// Insert at the sorted position so the result stays ordered and
		// duplicates are detected by the same binary search.
		i, ok := slices.BinarySearchFunc(hosts, host, cmpHost)
		if ok {
			return hosts, fmt.Errorf("duplicate entry in known hosts file: %s", host)
		}
		hosts = slices.Insert(hosts, i, host)
	}
	return hosts, scanner.Err()
}

// parseHost decodes one line of the known hosts file into a Host.
//
// The line must consist of exactly three whitespace-separated fields, in
// order: the address, the box public key and the signature public key.
// An error is returned if the field count is wrong or any field fails to
// decode; the returned Host is then the zero value.
func parseHost(b []byte) (Host, error) {
	var none Host

	cols := bytes.Fields(b)
	if n := len(cols); n != 3 {
		return none, fmt.Errorf("expected 3 fields; got %d", n)
	}
	rawAddr, rawBox, rawSig := cols[0], cols[1], cols[2]

	ip, err := netip.ParseAddr(string(rawAddr))
	if err != nil {
		return none, err
	}

	boxKey, err := key.DecodeBoxPublicKey(rawBox)
	if err != nil {
		return none, err
	}

	sigKey, err := key.DecodeSigPublicKey(rawSig)
	if err != nil {
		return none, err
	}

	return Host{
		Addr:         ip,
		BoxPublicKey: boxKey,
		SigPublicKey: sigKey,
	}, nil
}

// Store stores the set of known hosts to disc. It overwrites the entire file.
//
// hosts is sorted in place with cmpHost before being written, one entry per
// line in the format produced by Host.String. The directory that holds the
// known hosts file is created if it does not exist yet.
//
// It returns the first error encountered while creating the directory,
// creating the file, writing an entry, flushing, or closing the file.
func Store(hosts []Host) (err error) {
	slices.SortFunc(hosts, cmpHost)

	// The data directory may not exist on first use; os.Create would fail.
	if err := os.MkdirAll(filepath.Dir(knownHostsFile), 0o700); err != nil {
		return err
	}

	f, err := os.Create(knownHostsFile)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean data never reached the disc, so report it
		// unless an earlier error is already being returned.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()

	w := bufio.NewWriter(f)
	for _, host := range hosts {
		if _, err := fmt.Fprintf(w, "%s\n", host); err != nil {
			return err
		}
	}

	return w.Flush()
}

// cmpHost orders two hosts: first by address, then by box public key, and
// finally by signature public key. It returns the result of the first
// comparison that is non-zero, or the result of the signature key comparison
// when address and box key are equal.
func cmpHost(a, b Host) int {
	order := a.Addr.Compare(b.Addr)
	if order == 0 {
		order = a.BoxPublicKey.Compare(b.BoxPublicKey)
	}
	if order == 0 {
		order = a.SigPublicKey.Compare(b.SigPublicKey)
	}
	return order
}

// String renders the host as three space-separated fields: the address,
// then the box public key and the signature public key, both formatted
// with the %x verb. Store writes this form, one host per line.
func (h Host) String() string {
	const layout = "%s %x %x"
	return fmt.Sprintf(layout, h.Addr, h.BoxPublicKey, h.SigPublicKey)
}