From 6d9e7ca21637a46e643c88ff8c74e884ae908ceb Mon Sep 17 00:00:00 2001 From: Sam Anthony Date: Fri, 11 Apr 2025 15:49:46 -0400 Subject: create util package --- handshake.go | 26 ++++++++++++--------- key/generate.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ key/keygen.go | 68 ----------------------------------------------------- main.go | 33 ++++++++++---------------- util/util.go | 16 +++++++++++++ 5 files changed, 115 insertions(+), 100 deletions(-) create mode 100644 key/generate.go delete mode 100644 key/keygen.go create mode 100644 util/util.go diff --git a/handshake.go b/handshake.go index f741ef7..aff5fa2 100644 --- a/handshake.go +++ b/handshake.go @@ -1,11 +1,14 @@ package main import ( + "context" "fmt" "golang.org/x/sync/errgroup" "io" "net" + "git.samanthony.xyz/hose/util" + "git.samanthony.xyz/hose/hosts" "git.samanthony.xyz/hose/key" ) @@ -13,8 +16,8 @@ import ( // The user is asked to verify the fingerprint of the received key // before it is saved in the known hosts file. func handshake(rhost string) error { - logf("initiating handshake with %s...", rhost) - var group errgroup.Group + util.Logf("initiating handshake with %s...", rhost) + group, _ := errgroup.WithContext(context.Background()) group.Go(func() error { return handshakeSend(rhost) }) group.Go(func() error { return handshakeRecv(rhost) }) return group.Wait() @@ -22,25 +25,26 @@ func handshake(rhost string) error { // handshakeSend sends the local public key to a remote host. func handshakeSend(rhost string) error { + util.Logf("loading public key...") pubkey, err := key.LoadPublicKey() if err != nil { return err } raddr := net.JoinHostPort(rhost, port) - logf("connecting to %s...", raddr) + util.Logf("connecting to %s...", raddr) conn, err := net.Dial(network, raddr) if err != nil { return err } defer conn.Close() - logf("connected to %s", raddr) + util.Logf("connected to %s", raddr) if _, err := conn.Write(pubkey[:]); err != nil { return err } - logf("sent public key to %s", rhost) + util.Logf("sent public key to %s", rhost) return nil } @@ -55,14 +59,14 @@ func handshakeRecv(rhost string) error { return err } defer ln.Close() - logf("listening on %s", laddr) + util.Logf("listening on %s", laddr) conn, err := ln.Accept() if err != nil { return err } defer conn.Close() - logf("accepted connection from %s", conn.RemoteAddr()) + util.Logf("accepted connection from %s", conn.RemoteAddr()) // Receive public key from remote host. var rpubkey [32]byte @@ -70,7 +74,7 @@ func handshakeRecv(rhost string) error { if err != nil { return err } - logf("received public key from $s", conn.RemoteAddr()) + util.Logf("received public key from $s", conn.RemoteAddr()) // Ask user to verify the fingerprint of the key. ok, err := verifyPublicKey(conn.RemoteAddr(), rpubkey) @@ -82,7 +86,7 @@ func handshakeRecv(rhost string) error { return fmt.Errorf("host key verification failed") } - return addKnownHost(conn.RemoteAddr(), rpubkey) + return hosts.Set(conn.RemoteAddr(), rpubkey) } // verifyPublicKey asks the user to verify the fingerprint of a public key belonging to a remote host. @@ -95,7 +99,7 @@ func verifyPublicKey(addr net.Addr, pubkey [32]byte) (bool, error) { } // Ask host to verify fingerprint. - logf("Fingerprint of host %q: %x\nIs this the correct fingerprint (yes/[no])?", + util.Logf("Fingerprint of host %q: %x\nIs this the correct fingerprint (yes/[no])?", hostname, fingerprint(pubkey[:])) var response string n, err := fmt.Scanln(&response) @@ -103,7 +107,7 @@ func verifyPublicKey(addr net.Addr, pubkey [32]byte) (bool, error) { return false, err } for n > 0 && response != "yes" && response != "no" { - logf("Please type 'yes' or 'no'") + util.Logf("Please type 'yes' or 'no'") n, err = fmt.Scanln(&response) if err != nil { return false, err diff --git a/key/generate.go b/key/generate.go new file mode 100644 index 0000000..01ae33e --- /dev/null +++ b/key/generate.go @@ -0,0 +1,72 @@ +package key + +import ( + crypto_rand "crypto/rand" + "fmt" + "golang.org/x/crypto/nacl/box" + "os" + + "git.samanthony.xyz/hose/util" +) + +// Generate generates a new public/private keypair. It stores the private key in the +// private key file and the public key in the public key file. If either of the key +// files already exist, they will not be overwritten; instead an error will be returned. +func Generate() error { + util.Logf("generating new keypair...") + + // Create public key file. + pubFile, err := createFile(pubKeyFile, pubKeyFileMode) + if err != nil { + return err + } + defer pubFile.Close() + + // Create private key file. + privFile, err := createFile(privKeyFile, privKeyFileMode) + if err != nil { + pubFile.Close() + _ = os.Remove(pubKeyFile) + return err + } + defer privFile.Close() + + // Generate keypair. + pubkey, privkey, err := box.GenerateKey(crypto_rand.Reader) + if err != nil { + return err + } + + // Write keypair to files. + if _, err := pubFile.Write((*pubkey)[:]); err != nil { + return err + } + if _, err := privFile.Write((*privkey)[:]); err != nil { + return err + } + + return nil +} + +// Generate a keypair if it doesn't already exist. +func generateIfNoExist() error { + pubExists, err := fileExists(pubKeyFile) + if err != nil { + return err + } + privExists, err := fileExists(privKeyFile) + if err != nil { + return err + } + + if pubExists && privExists { + // Keypair already exists. + return nil + } else if pubExists && !privExists { + return fmt.Errorf("found public key file but not private key file") + } else if privExists && !pubExists { + return fmt.Errorf("found private key file but not public key file") + } + // Neither public nor private key file exists; generate new keypair. + return Generate() +} diff --git a/key/keygen.go b/key/keygen.go deleted file mode 100644 index 813df31..0000000 --- a/key/keygen.go +++ /dev/null @@ -1,68 +0,0 @@ -package key - -import ( - crypto_rand "crypto/rand" - "fmt" - "golang.org/x/crypto/nacl/box" - "os" -) - -// Generate generates a new public/private keypair. It stores the private key in the -// private key file and the public key in the public key file. If either of the key -// files already exist, they will not be overwritten; instead an error will be returned. -func Generate() error { - // Create public key file. - pubFile, err := createFile(pubKeyFile, pubKeyFileMode) - if err != nil { - return err - } - defer pubFile.Close() - - // Create private key file. - privFile, err := createFile(privKeyFile, privKeyFileMode) - if err != nil { - pubFile.Close() - _ = os.Remove(pubKeyFile) - return err - } - defer privFile.Close() - - // Generate keypair. - pubkey, privkey, err := box.GenerateKey(crypto_rand.Reader) - if err != nil { - return err - } - - // Write keypair to files. - if _, err := pubFile.Write((*pubkey)[:]); err != nil { - return err - } - if _, err := privFile.Write((*privkey)[:]); err != nil { - return err - } - - return nil -} - -// Generate a keypair if it doesn't already exist. -func generateIfNoExist() error { - pubExists, err := fileExists(pubKeyFile) - if err != nil { - return err - } - privExists, err := fileExists(privKeyFile) - if err != nil { - return err - } - - if pubExists && privExists { - // Keypair already exists. - return nil - } else if pubExists && !privExists { - return fmt.Errorf("found public key file but not private key file") - } else if privExists && !pubExists { - return fmt.Errorf("found private key file but not public key file") - } - // Neither public nor private key file exists; generate new keypair. - return Generate() -} diff --git a/main.go b/main.go index 66ea5bd..161a21b 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,12 @@ package main import ( "flag" - "fmt" "github.com/tonistiigi/units" "io" "net" "os" + + "git.samanthony.xyz/hose/util" ) const ( @@ -25,18 +26,18 @@ func main() { flag.Parse() if *handshakeHost != "" { if err := handshake(*handshakeHost); err != nil { - eprintf("%v\n", err) + util.Eprintf("%v\n", err) } } else if *recvFlag { if err := recv(); err != nil { - eprintf("%v\n", err) + util.Eprintf("%v\n", err) } } else if *sendHost != "" { if err := send(*sendHost); err != nil { - eprintf("%v\n", err) + util.Eprintf("%v\n", err) } } else { - logf("%s", usage) + util.Logf("%s", usage) flag.Usage() os.Exit(1) } @@ -50,42 +51,32 @@ func recv() error { return err } defer ln.Close() - logf("listening on %s", laddr) + util.Logf("listening on %s", laddr) conn, err := ln.Accept() if err != nil { return err } defer conn.Close() - logf("accepted connection from %s", conn.RemoteAddr()) + util.Logf("accepted connection from %s", conn.RemoteAddr()) n, err := io.Copy(os.Stdout, conn) - logf("received %.2f", units.Bytes(n)*units.B) + util.Logf("received %.2f", units.Bytes(n)*units.B) return err } // send pipes data from stdin to the remote host. func send(rhost string) error { raddr := net.JoinHostPort(rhost, port) - logf("connecting to %s...", raddr) + util.Logf("connecting to %s...", raddr) conn, err := net.Dial(network, raddr) if err != nil { return err } defer conn.Close() - logf("connected to %s", raddr) + util.Logf("connected to %s", raddr) n, err := io.Copy(conn, os.Stdin) - logf("sent %.2f", units.Bytes(n)*units.B) + util.Logf("sent %.2f", units.Bytes(n)*units.B) return err } - -func eprintf(format string, a ...any) { - logf(format, a...) - os.Exit(1) -} - -func logf(format string, a ...any) { - msg := fmt.Sprintf(format, a...) - fmt.Fprintf(os.Stderr, "%s\n", msg) -} diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..0365331 --- /dev/null +++ b/util/util.go @@ -0,0 +1,16 @@ +package util + +import ( + "fmt" + "os" +) + +func Eprintf(format string, a ...any) { + Logf(format, a...) + os.Exit(1) +} + +func Logf(format string, a ...any) { + msg := fmt.Sprintf(format, a...) + fmt.Fprintf(os.Stderr, "%s\n", msg) +} -- cgit v1.2.3