aboutsummaryrefslogtreecommitdiffstats
path: root/hosts/hosts.go
blob: d31ed5e750e3793ad6ccf6426a85a6a4b71ba610 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package hosts

import (
	"bufio"
	"encoding/hex"
	"errors"
	"fmt"
	"github.com/adrg/xdg"
	"net"
	"net/netip"
	"os"
	"path/filepath"
	"strings"
)

// knownHostsFile is the path of the known hosts file read by Load and
// overwritten by Store: the "hose/known_hosts" file under xdg.DataHome.
var knownHostsFile = filepath.Join(xdg.DataHome, "hose", "known_hosts")

// Set sets the public key of a remote host.
// It replaces or creates an entry in the known hosts file.
//
// host may stringify either to a bare IP address (e.g. *net.IPAddr) or to
// an "ip:port" pair (e.g. *net.TCPAddr, *net.UDPAddr); in the latter case
// the port is discarded, since known hosts are keyed by IP address only.
//
// Set returns an error if the address cannot be parsed, or if loading or
// storing the known hosts file fails.
func Set(host net.Addr, pubkey [32]byte) error {
	raw := host.String()

	addr, err := netip.ParseAddr(raw)
	if err != nil {
		// Addresses obtained from a connection carry a port, which
		// ParseAddr rejects; retry as an address-port pair.
		ap, perr := netip.ParseAddrPort(raw)
		if perr != nil {
			return fmt.Errorf("invalid host address %q: %w", raw, err)
		}
		addr = ap.Addr()
	}

	hosts, err := Load()
	if err != nil {
		return err
	}

	hosts[addr] = pubkey

	return Store(hosts)
}

// Load loads the set of known hosts and their associated public keys
// from disk.
//
// A missing known hosts file is not an error: Load returns an empty map.
// Each line of the file must hold exactly one "address key" pair; a
// malformed line or a repeated address makes Load return an error. Parse
// errors identify the offending file and line number.
func Load() (map[netip.Addr][32]byte, error) {
	hosts := make(map[netip.Addr][32]byte)

	f, err := os.Open(knownHostsFile)
	if errors.Is(err, os.ErrNotExist) {
		return hosts, nil // no known hosts yet.
	}
	if err != nil {
		return hosts, err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	// line is 1-based so that error messages match editor line numbers.
	for line := 1; scanner.Scan(); line++ {
		host, pubkey, err := parseHostKeyPair(scanner.Text())
		if err != nil {
			return hosts, fmt.Errorf("error parsing known hosts file: %s:%d: %w", knownHostsFile, line, err)
		}
		if _, ok := hosts[host]; ok {
			return hosts, fmt.Errorf("duplicate entry in known hosts file: %s", host)
		}
		hosts[host] = pubkey
	}
	// Scan returning false may mean a read error rather than EOF.
	return hosts, scanner.Err()
}

// parseHostKeyPair decodes one line of the known hosts file.
//
// A line consists of exactly two whitespace-separated fields: an IP
// address followed by a hex-encoded 32-byte public key. On any failure
// the zero address and zero key are returned together with the error.
func parseHostKeyPair(s string) (netip.Addr, [32]byte, error) {
	var (
		noAddr netip.Addr
		noKey  [32]byte
	)

	parts := strings.Fields(s)
	if n := len(parts); n != 2 {
		return noAddr, noKey, fmt.Errorf("expected 2 fields; got %d", n)
	}
	rawAddr, rawKey := parts[0], parts[1]

	ip, err := netip.ParseAddr(rawAddr)
	if err != nil {
		return noAddr, noKey, err
	}

	// Check the length up front so Decode never writes past the array.
	var pubkey [32]byte
	if hex.DecodedLen(len(rawKey)) != len(pubkey) {
		return noAddr, noKey, fmt.Errorf("malformed key: %s", rawKey)
	}
	if _, err = hex.Decode(pubkey[:], []byte(rawKey)); err != nil {
		return noAddr, noKey, err
	}

	return ip, pubkey, nil
}

// Store stores the set of known hosts and their associated public keys
// to disk. It overwrites the entire file.
//
// Each entry is written on its own line as "<address> <hex key>", the
// format that Load reads back. The directory containing the known hosts
// file is created if it does not exist yet. Store returns the first error
// encountered while creating, writing or closing the file.
func Store(hosts map[netip.Addr][32]byte) error {
	// os.Create does not create missing parent directories.
	if err := os.MkdirAll(filepath.Dir(knownHostsFile), 0o700); err != nil {
		return err
	}

	f, err := os.Create(knownHostsFile)
	if err != nil {
		return err
	}

	for host, key := range hosts {
		// The trailing newline keeps one entry per line; without it
		// multiple entries run together and Load rejects the file.
		if _, err := fmt.Fprintf(f, "%s %x\n", host, key); err != nil {
			f.Close()
			return err
		}
	}

	// Close can report a deferred write failure, so its error matters.
	return f.Close()
}