aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--handshake.go26
-rw-r--r--key/generate.go (renamed from key/keygen.go)4
-rw-r--r--main.go33
-rw-r--r--util/util.go16
4 files changed, 47 insertions, 32 deletions
diff --git a/handshake.go b/handshake.go
index f741ef7..aff5fa2 100644
--- a/handshake.go
+++ b/handshake.go
@@ -1,11 +1,14 @@
package main
import (
+ "context"
"fmt"
"golang.org/x/sync/errgroup"
"io"
"net"
+ "git.samanthony.xyz/hose/util"
+ "git.samanthony.xyz/hose/hosts"
"git.samanthony.xyz/hose/key"
)
@@ -13,8 +16,8 @@ import (
// The user is asked to verify the fingerprint of the received key
// before it is saved in the known hosts file.
func handshake(rhost string) error {
- logf("initiating handshake with %s...", rhost)
- var group errgroup.Group
+ util.Logf("initiating handshake with %s...", rhost)
+ group, _ := errgroup.WithContext(context.Background())
group.Go(func() error { return handshakeSend(rhost) })
group.Go(func() error { return handshakeRecv(rhost) })
return group.Wait()
@@ -22,25 +25,26 @@ func handshake(rhost string) error {
// handshakeSend sends the local public key to a remote host.
func handshakeSend(rhost string) error {
+ util.Logf("loading public key...")
pubkey, err := key.LoadPublicKey()
if err != nil {
return err
}
raddr := net.JoinHostPort(rhost, port)
- logf("connecting to %s...", raddr)
+ util.Logf("connecting to %s...", raddr)
conn, err := net.Dial(network, raddr)
if err != nil {
return err
}
defer conn.Close()
- logf("connected to %s", raddr)
+ util.Logf("connected to %s", raddr)
if _, err := conn.Write(pubkey[:]); err != nil {
return err
}
- logf("sent public key to %s", rhost)
+ util.Logf("sent public key to %s", rhost)
return nil
}
@@ -55,14 +59,14 @@ func handshakeRecv(rhost string) error {
return err
}
defer ln.Close()
- logf("listening on %s", laddr)
+ util.Logf("listening on %s", laddr)
conn, err := ln.Accept()
if err != nil {
return err
}
defer conn.Close()
- logf("accepted connection from %s", conn.RemoteAddr())
+ util.Logf("accepted connection from %s", conn.RemoteAddr())
// Receive public key from remote host.
var rpubkey [32]byte
@@ -70,7 +74,7 @@ func handshakeRecv(rhost string) error {
if err != nil {
return err
}
- logf("received public key from $s", conn.RemoteAddr())
+ util.Logf("received public key from $s", conn.RemoteAddr())
// Ask user to verify the fingerprint of the key.
ok, err := verifyPublicKey(conn.RemoteAddr(), rpubkey)
@@ -82,7 +86,7 @@ func handshakeRecv(rhost string) error {
return fmt.Errorf("host key verification failed")
}
- return addKnownHost(conn.RemoteAddr(), rpubkey)
+ return hosts.Set(conn.RemoteAddr(), rpubkey)
}
// verifyPublicKey asks the user to verify the fingerprint of a public key belonging to a remote host.
@@ -95,7 +99,7 @@ func verifyPublicKey(addr net.Addr, pubkey [32]byte) (bool, error) {
}
// Ask host to verify fingerprint.
- logf("Fingerprint of host %q: %x\nIs this the correct fingerprint (yes/[no])?",
+ util.Logf("Fingerprint of host %q: %x\nIs this the correct fingerprint (yes/[no])?",
hostname, fingerprint(pubkey[:]))
var response string
n, err := fmt.Scanln(&response)
@@ -103,7 +107,7 @@ func verifyPublicKey(addr net.Addr, pubkey [32]byte) (bool, error) {
return false, err
}
for n > 0 && response != "yes" && response != "no" {
- logf("Please type 'yes' or 'no'")
+ util.Logf("Please type 'yes' or 'no'")
n, err = fmt.Scanln(&response)
if err != nil {
return false, err
diff --git a/key/keygen.go b/key/generate.go
index 813df31..01ae33e 100644
--- a/key/keygen.go
+++ b/key/generate.go
@@ -5,12 +5,16 @@ import (
"fmt"
"golang.org/x/crypto/nacl/box"
"os"
+
+ "git.samanthony.xyz/hose/util"
)
// Generate generates a new public/private keypair. It stores the private key in the
// private key file and the public key in the public key file. If either of the key
// files already exist, they will not be overwritten; instead an error will be returned.
func Generate() error {
+ util.Logf("generating new keypair...")
+
// Create public key file.
pubFile, err := createFile(pubKeyFile, pubKeyFileMode)
if err != nil {
diff --git a/main.go b/main.go
index 66ea5bd..161a21b 100644
--- a/main.go
+++ b/main.go
@@ -2,11 +2,12 @@ package main
import (
"flag"
- "fmt"
"github.com/tonistiigi/units"
"io"
"net"
"os"
+
+ "git.samanthony.xyz/hose/util"
)
const (
@@ -25,18 +26,18 @@ func main() {
flag.Parse()
if *handshakeHost != "" {
if err := handshake(*handshakeHost); err != nil {
- eprintf("%v\n", err)
+ util.Eprintf("%v\n", err)
}
} else if *recvFlag {
if err := recv(); err != nil {
- eprintf("%v\n", err)
+ util.Eprintf("%v\n", err)
}
} else if *sendHost != "" {
if err := send(*sendHost); err != nil {
- eprintf("%v\n", err)
+ util.Eprintf("%v\n", err)
}
} else {
- logf("%s", usage)
+ util.Logf("%s", usage)
flag.Usage()
os.Exit(1)
}
@@ -50,42 +51,32 @@ func recv() error {
return err
}
defer ln.Close()
- logf("listening on %s", laddr)
+ util.Logf("listening on %s", laddr)
conn, err := ln.Accept()
if err != nil {
return err
}
defer conn.Close()
- logf("accepted connection from %s", conn.RemoteAddr())
+ util.Logf("accepted connection from %s", conn.RemoteAddr())
n, err := io.Copy(os.Stdout, conn)
- logf("received %.2f", units.Bytes(n)*units.B)
+ util.Logf("received %.2f", units.Bytes(n)*units.B)
return err
}
// send pipes data from stdin to the remote host.
func send(rhost string) error {
raddr := net.JoinHostPort(rhost, port)
- logf("connecting to %s...", raddr)
+ util.Logf("connecting to %s...", raddr)
conn, err := net.Dial(network, raddr)
if err != nil {
return err
}
defer conn.Close()
- logf("connected to %s", raddr)
+ util.Logf("connected to %s", raddr)
n, err := io.Copy(conn, os.Stdin)
- logf("sent %.2f", units.Bytes(n)*units.B)
+ util.Logf("sent %.2f", units.Bytes(n)*units.B)
return err
}
-
-func eprintf(format string, a ...any) {
- logf(format, a...)
- os.Exit(1)
-}
-
-func logf(format string, a ...any) {
- msg := fmt.Sprintf(format, a...)
- fmt.Fprintf(os.Stderr, "%s\n", msg)
-}
diff --git a/util/util.go b/util/util.go
new file mode 100644
index 0000000..0365331
--- /dev/null
+++ b/util/util.go
@@ -0,0 +1,16 @@
+package util
+
+import (
+ "fmt"
+ "os"
+)
+
+func Eprintf(format string, a ...any) {
+ Logf(format, a...)
+ os.Exit(1)
+}
+
+func Logf(format string, a ...any) {
+ msg := fmt.Sprintf(format, a...)
+ fmt.Fprintf(os.Stderr, "%s\n", msg)
+}