From 60179461446e4ed7a1c7c764444cb7e6dfc1a579 Mon Sep 17 00:00:00 2001
From: Hidde Beydals <hello@hidde.co>
Date: Tue, 9 Jun 2020 01:11:46 +0200
Subject: [PATCH] Improve host key scanner, add Ed25519 generator

---
 cmd/tk/create_source_git.go | 67 +++++++++++++++++++++++--------------
 cmd/tk/flags.go             | 55 ++++++++++++++++++------------
 pkg/ssh/host_scan.go        | 23 ++++++-------
 pkg/ssh/key_pair.go         | 26 ++++++++++++++
 4 files changed, 111 insertions(+), 60 deletions(-)

diff --git a/cmd/tk/create_source_git.go b/cmd/tk/create_source_git.go
index 8121a19f..f017cac2 100644
--- a/cmd/tk/create_source_git.go
+++ b/cmd/tk/create_source_git.go
@@ -2,10 +2,12 @@ package main
 
 import (
 	"context"
+	"crypto/elliptic"
 	"fmt"
 	"io/ioutil"
 	"net/url"
 	"os"
+	"time"
 
 	sourcev1 "github.com/fluxcd/source-controller/api/v1alpha1"
 	"github.com/manifoldco/promptui"
@@ -42,11 +44,19 @@ For private Git repositories, the basic authentication credentials are stored in
     --url=https://github.com/stefanprodan/podinfo \
     --tag-semver=">=3.2.0 <3.3.0"
 
-  #  Create a source from a Git repository using SSH authentication
+  # Create a source from a Git repository using SSH authentication
   create source git podinfo \
     --url=ssh://git@github.com/stefanprodan/podinfo \
     --branch=master
 
+  # Create a source from a Git repository using SSH authentication and an
+  # ECDSA P-521 curve public key
+  create source git podinfo \
+    --url=ssh://git@github.com/stefanprodan/podinfo \
+    --branch=master \
+    --ssh-key-algorithm=ecdsa \
+    --ssh-ecdsa-curve=p521
+
   # Create a source from a Git repository using basic authentication
   create source git podinfo \
     --url=https://github.com/stefanprodan/podinfo \
@@ -63,9 +73,9 @@ var (
 	sourceGitSemver       string
 	sourceGitUsername     string
 	sourceGitPassword     string
-	sourceGitKeyAlgorithm PublicKeyAlgorithm
-	sourceGitRSABits      RSAKeyBits
-	sourceGitECDSACurve   ECDSACurve
+	sourceGitKeyAlgorithm PublicKeyAlgorithm = "rsa"
+	sourceGitRSABits      RSAKeyBits         = 2048
+	sourceGitECDSACurve                      = ECDSACurve{elliptic.P384()}
 )
 
 func init() {
@@ -75,9 +85,9 @@ func init() {
 	createSourceGitCmd.Flags().StringVar(&sourceGitSemver, "tag-semver", "", "git tag semver range")
 	createSourceGitCmd.Flags().StringVarP(&sourceGitUsername, "username", "u", "", "basic authentication username")
 	createSourceGitCmd.Flags().StringVarP(&sourceGitPassword, "password", "p", "", "basic authentication password")
-	createSourceGitCmd.Flags().Var(&sourceGitKeyAlgorithm, "ssh-algorithm", "SSH public key algorithm")
-	createSourceGitCmd.Flags().Var(&sourceGitRSABits, "ssh-rsa-bits", "SSH RSA public key bit size")
-	createSourceGitCmd.Flags().Var(&sourceGitECDSACurve, "ssh-ecdsa-curve", "SSH ECDSA public key curve")
+	createSourceGitCmd.Flags().Var(&sourceGitKeyAlgorithm, "ssh-key-algorithm", sourceGitKeyAlgorithm.Description())
+	createSourceGitCmd.Flags().Var(&sourceGitRSABits, "ssh-rsa-bits", sourceGitRSABits.Description())
+	createSourceGitCmd.Flags().Var(&sourceGitECDSACurve, "ssh-ecdsa-curve", sourceGitECDSACurve.Description())
 
 	createSourceCmd.AddCommand(createSourceGitCmd)
 }
@@ -108,18 +118,11 @@ func createSourceGitCmdRun(cmd *cobra.Command, args []string) error {
 
 	withAuth := false
 	if u.Scheme == "ssh" {
-		var keyGen ssh.KeyPairGenerator
-		switch sourceGitKeyAlgorithm.String() {
-		case "rsa":
-			keyGen = ssh.NewRSAGenerator(int(sourceGitRSABits))
-		case "ecdsa":
-			keyGen = ssh.NewECDSAGenerator(sourceGitECDSACurve.Curve)
-		}
 		host := u.Host
 		if u.Port() == "" {
 			host = host + ":22"
 		}
-		if err := generateSSH(ctx, keyGen, name, host, tmpDir); err != nil {
+		if err := generateSSH(ctx, name, host); err != nil {
 			return err
 		}
 		withAuth = true
@@ -212,13 +215,14 @@ func generateBasicAuth(ctx context.Context, name string) error {
 	return nil
 }
 
-func generateSSH(ctx context.Context, generator ssh.KeyPairGenerator, name, host, user string) error {
-	logGenerate("generating deploy key")
-	kp, err := generator.Generate()
+func generateSSH(ctx context.Context, name, host string) error {
+	gen := getKeyPairGenerator()
+	logGenerate("generating deploy key pair")
+	pair, err := gen.Generate()
 	if err != nil {
-		return fmt.Errorf("SSH key pair generation failed: %w", err)
+		return fmt.Errorf("key pair generation failed: %w", err)
 	}
-	fmt.Printf("%s", kp.PublicKey)
+	fmt.Printf("%s", pair.PublicKey)
 
 	prompt := promptui.Prompt{
 		Label:     "Have you added the deploy key to your repository",
@@ -228,21 +232,21 @@ func generateSSH(ctx context.Context, generator ssh.KeyPairGenerator, name, host
 		return fmt.Errorf("aborting")
 	}
 
-	logAction("collecting SSH server public key for generated public key algorithm")
-	hostKey, err := ssh.ScanHostKey(host, user, kp)
+	logAction("collecting preferred public key from SSH server")
+	hostKey, err := ssh.ScanHostKey(host, 30*time.Second)
 	if err != nil {
 		return err
 	}
-	logSuccess("collected public key from SSH server")
+	logSuccess("collected public key from SSH server:")
 	fmt.Printf("%s", hostKey)
 
 	logAction("saving keys")
 	files := fmt.Sprintf("--from-literal=identity=\"%s\" --from-literal=identity.pub=\"%s\" --from-literal=known_hosts=\"%s\"",
-		kp.PublicKey, kp.PrivateKey, hostKey)
+		pair.PrivateKey, pair.PublicKey, hostKey)
 	secret := fmt.Sprintf("kubectl -n %s create secret generic %s %s --dry-run=client -oyaml | kubectl apply -f-",
 		namespace, name, files)
 	if _, err := utils.execCommand(ctx, ModeOS, secret); err != nil {
-		return fmt.Errorf("create secret failed")
+		return fmt.Errorf("failed to create secret")
 	}
 	return nil
 }
@@ -301,3 +305,16 @@ func isGitRepositoryReady(ctx context.Context, kubeClient client.Client, name, n
 		return false, nil
 	}
 }
+
+func getKeyPairGenerator() ssh.KeyPairGenerator {
+	var keyGen ssh.KeyPairGenerator
+	switch sourceGitKeyAlgorithm.String() {
+	case "rsa":
+		keyGen = ssh.NewRSAGenerator(int(sourceGitRSABits))
+	case "ecdsa":
+		keyGen = ssh.NewECDSAGenerator(sourceGitECDSACurve.Curve)
+	case "ed25519":
+		keyGen = ssh.NewEd25519Generator()
+	}
+	return keyGen
+}
diff --git a/cmd/tk/flags.go b/cmd/tk/flags.go
index f91e729c..2ae65b9e 100644
--- a/cmd/tk/flags.go
+++ b/cmd/tk/flags.go
@@ -7,7 +7,7 @@ import (
 	"strings"
 )
 
-var supportedPublicKeyAlgorithms = []string{"rsa", "ecdsa"}
+var supportedPublicKeyAlgorithms = []string{"rsa", "ecdsa", "ed25519"}
 
 type PublicKeyAlgorithm string
 
@@ -17,8 +17,8 @@ func (a *PublicKeyAlgorithm) String() string {
 
 func (a *PublicKeyAlgorithm) Set(str string) error {
 	if strings.TrimSpace(str) == "" {
-		*a = PublicKeyAlgorithm(supportedPublicKeyAlgorithms[0])
-		return nil
+		return fmt.Errorf("no public key algorithm given, must be one of: %s",
+			strings.Join(supportedPublicKeyAlgorithms, ", "))
 	}
 	for _, v := range supportedPublicKeyAlgorithms {
 		if str == v {
@@ -26,17 +26,18 @@ func (a *PublicKeyAlgorithm) Set(str string) error {
 			return nil
 		}
 	}
-	return fmt.Errorf(
-		"unsupported public key algorithm '%s', must be one of: %s",
-		str,
-		strings.Join(supportedPublicKeyAlgorithms, ", "),
-	)
+	return fmt.Errorf("unsupported public key algorithm '%s', must be one of: %s",
+		str, strings.Join(supportedPublicKeyAlgorithms, ", "))
 }
 
 func (a *PublicKeyAlgorithm) Type() string {
 	return "publicKeyAlgorithm"
 }
 
+func (a *PublicKeyAlgorithm) Description() string {
+	return fmt.Sprintf("SSH public key algorithm (%s)", strings.Join(supportedPublicKeyAlgorithms, ", "))
+}
+
 var defaultRSAKeyBits = 2048
 
 type RSAKeyBits int
@@ -65,37 +66,47 @@ func (b *RSAKeyBits) Type() string {
 	return "rsaKeyBits"
 }
 
+func (b *RSAKeyBits) Description() string {
+	return "SSH RSA public key bit size (multiplies of 8)"
+}
+
 type ECDSACurve struct {
 	elliptic.Curve
 }
 
 var supportedECDSACurves = map[string]elliptic.Curve{
-	"P-256": elliptic.P256(),
-	"P-384": elliptic.P384(),
-	"P-521": elliptic.P521(),
+	"p256": elliptic.P256(),
+	"p384": elliptic.P384(),
+	"p521": elliptic.P521(),
 }
 
 func (c *ECDSACurve) String() string {
-	if c == nil || c.Curve == nil {
+	if c.Curve == nil {
 		return ""
 	}
-	return c.Curve.Params().Name
+	return strings.ToLower(strings.Replace(c.Curve.Params().Name, "-", "", 1))
 }
 
 func (c *ECDSACurve) Set(str string) error {
-	if strings.TrimSpace(str) == "" {
-		*c = ECDSACurve{supportedECDSACurves["P-384"]}
+	if v, ok := supportedECDSACurves[str]; ok {
+		*c = ECDSACurve{v}
 		return nil
 	}
-	for k, v := range supportedECDSACurves {
-		if k == str {
-			*c = ECDSACurve{v}
-			return nil
-		}
-	}
-	return fmt.Errorf("unsupported curve '%s', should be one of: P-256, P-384, P-521", str)
+	return fmt.Errorf("unsupported curve '%s', should be one of: %s", str, strings.Join(ecdsaCurves(), ", "))
 }
 
 func (c *ECDSACurve) Type() string {
 	return "ecdsaCurve"
 }
+
+func (c *ECDSACurve) Description() string {
+	return fmt.Sprintf("SSH ECDSA public key curve (%s)", strings.Join(ecdsaCurves(), ", "))
+}
+
+func ecdsaCurves() []string {
+	keys := make([]string, 0, len(supportedECDSACurves))
+	for k := range supportedECDSACurves {
+		keys = append(keys, k)
+	}
+	return keys
+}
diff --git a/pkg/ssh/host_scan.go b/pkg/ssh/host_scan.go
index 61844608..a7a1a20d 100644
--- a/pkg/ssh/host_scan.go
+++ b/pkg/ssh/host_scan.go
@@ -4,27 +4,20 @@ import (
 	"encoding/base64"
 	"fmt"
 	"net"
+	"time"
 
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/crypto/ssh/knownhosts"
 )
 
 // ScanHostKey collects the given host's preferred public key for the
-// algorithm of the given key pair. Any errors (e.g. authentication
-// failures) are ignored, except if no key could be collected from the
-// host.
-func ScanHostKey(host string, user string, pair *KeyPair) ([]byte, error) {
-	signer, err := ssh.ParsePrivateKey(pair.PrivateKey)
-	if err != nil {
-		return nil, err
-	}
+// Any errors (e.g. authentication  failures) are ignored, except if
+// no key could be collected from the host.
+func ScanHostKey(host string, timeout time.Duration) ([]byte, error) {
 	col := &collector{}
 	config := &ssh.ClientConfig{
-		User: user,
-		Auth: []ssh.AuthMethod{
-			ssh.PublicKeys(signer),
-		},
-		HostKeyCallback: col.StoreKey(),
+		HostKeyCallback:   col.StoreKey(),
+		Timeout:           timeout,
 	}
 	client, err := ssh.Dial("tcp", host, config)
 	if err == nil {
@@ -40,6 +33,10 @@ type collector struct {
 	knownKeys []byte
 }
 
+// StoreKey stores the public key in bytes as returned by the host.
+// To collect multiple public key types from the host, multiple
+// SSH dials need with the ClientConfig HostKeyAlgorithms set to
+// the algorithm you want to collect.
 func (c *collector) StoreKey() ssh.HostKeyCallback {
 	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
 		c.knownKeys = append(
diff --git a/pkg/ssh/key_pair.go b/pkg/ssh/key_pair.go
index bad4bf2e..9f8fcd40 100644
--- a/pkg/ssh/key_pair.go
+++ b/pkg/ssh/key_pair.go
@@ -2,6 +2,7 @@ package ssh
 
 import (
 	"crypto/ecdsa"
+	"crypto/ed25519"
 	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rsa"
@@ -79,6 +80,31 @@ func (g *ECDSAGenerator) Generate() (*KeyPair, error) {
 	}, nil
 }
 
+type Ed25519Generator struct{}
+
+func NewEd25519Generator() KeyPairGenerator {
+	return &Ed25519Generator{}
+}
+
+func (g *Ed25519Generator) Generate() (*KeyPair, error) {
+	pk, pv, err := ed25519.GenerateKey(rand.Reader)
+	if err != nil {
+		return nil, err
+	}
+	pub, err := generatePublicKey(pk)
+	if err != nil {
+		return nil, err
+	}
+	priv, err := encodePrivateKeyToPEM(pv)
+	if err != nil {
+		return nil, err
+	}
+	return &KeyPair{
+		PublicKey:  pub,
+		PrivateKey: priv,
+	}, nil
+}
+
 func generatePublicKey(pk interface{}) ([]byte, error) {
 	b, err := ssh.NewPublicKey(pk)
 	if err != nil {
-- 
GitLab