From a332e1233895bbcd9ac916c9c8675f3cf6dfe835 Mon Sep 17 00:00:00 2001
From: Hidde Beydals <hello@hidde.co>
Date: Fri, 5 Jun 2020 22:22:54 +0200
Subject: [PATCH] Replace SSH shell-outs with Go implementation

---
 cmd/tk/check.go             | 18 --------
 cmd/tk/create_source_git.go | 79 ++++++++++++++++++--------------
 go.mod                      |  1 +
 internal/ssh/host_scan.go   | 47 ++++++++++++++++++++
 internal/ssh/key_pair.go    | 89 +++++++++++++++++++++++++++++++++++++
 5 files changed, 183 insertions(+), 51 deletions(-)
 create mode 100644 internal/ssh/host_scan.go
 create mode 100644 internal/ssh/key_pair.go

diff --git a/cmd/tk/check.go b/cmd/tk/check.go
index a74fee3e..29cf999b 100644
--- a/cmd/tk/check.go
+++ b/cmd/tk/check.go
@@ -45,9 +45,6 @@ func runCheckCmd(cmd *cobra.Command, args []string) error {
 
 	logAction("checking prerequisites")
 	checkFailed := false
-	if !sshCheck() {
-		checkFailed = true
-	}
 
 	if !kubectlCheck(ctx, ">=1.18.0") {
 		checkFailed = true
@@ -76,21 +73,6 @@ func runCheckCmd(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func sshCheck() bool {
-	ok := true
-	for _, cmd := range []string{"ssh-keygen", "ssh-keyscan"} {
-		_, err := exec.LookPath(cmd)
-		if err != nil {
-			logFailure("%s not found", cmd)
-			ok = false
-		} else {
-			logSuccess("%s found", cmd)
-		}
-	}
-
-	return ok
-}
-
 func kubectlCheck(ctx context.Context, version string) bool {
 	_, err := exec.LookPath("kubectl")
 	if err != nil {
diff --git a/cmd/tk/create_source_git.go b/cmd/tk/create_source_git.go
index b5953b27..694e0842 100644
--- a/cmd/tk/create_source_git.go
+++ b/cmd/tk/create_source_git.go
@@ -2,20 +2,24 @@ package main
 
 import (
 	"context"
+	"crypto/elliptic"
 	"fmt"
+	"io/ioutil"
+	"net/url"
+	"os"
+	"strings"
+
 	sourcev1 "github.com/fluxcd/source-controller/api/v1alpha1"
 	"github.com/manifoldco/promptui"
 	"github.com/spf13/cobra"
-	"io/ioutil"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/types"
 	"k8s.io/apimachinery/pkg/util/wait"
-	"net/url"
-	"os"
 	"sigs.k8s.io/controller-runtime/pkg/client"
-	"strings"
+
+	"github.com/fluxcd/toolkit/internal/ssh"
 )
 
 var createSourceGitCmd = &cobra.Command{
@@ -55,12 +59,14 @@ For private Git repositories, the basic authentication credentials are stored in
 }
 
 var (
-	sourceGitURL      string
-	sourceGitBranch   string
-	sourceGitTag      string
-	sourceGitSemver   string
-	sourceGitUsername string
-	sourceGitPassword string
+	sourceGitURL          string
+	sourceGitBranch       string
+	sourceGitTag          string
+	sourceGitSemver       string
+	sourceGitUsername     string
+	sourceGitPassword     string
+	sourceGitKeyAlgorithm string
+	sourceGitRSABits      int
 )
 
 func init() {
@@ -70,6 +76,8 @@ func init() {
 	createSourceGitCmd.Flags().StringVar(&sourceGitSemver, "tag-semver", "", "git tag semver range")
 	createSourceGitCmd.Flags().StringVarP(&sourceGitUsername, "username", "u", "", "basic authentication username")
 	createSourceGitCmd.Flags().StringVarP(&sourceGitPassword, "password", "p", "", "basic authentication password")
+	createSourceGitCmd.Flags().StringVarP(&sourceGitKeyAlgorithm, "ssh-algorithm", "", "rsa", "SSH public key algorithm")
+	createSourceGitCmd.Flags().IntVarP(&sourceGitRSABits, "ssh-rsa-bits", "", 2048, "SSH RSA public key bit size")
 
 	createSourceCmd.AddCommand(createSourceGitCmd)
 }
@@ -99,8 +107,20 @@ func createSourceGitCmdRun(cmd *cobra.Command, args []string) error {
 	defer cancel()
 
 	withAuth := false
-	if strings.HasPrefix(sourceGitURL, "ssh") {
-		if err := generateSSH(ctx, name, u.Host, tmpDir); err != nil {
+	if u.Scheme == "ssh" {
+		var keyGen ssh.KeyPairGenerator
+		switch strings.ToLower(sourceGitKeyAlgorithm) {
+		case "rsa":
+			keyGen = ssh.NewRSAGenerator(sourceGitRSABits)
+		case "ecdsa":
+			// TODO(hidde): make curve configurable by flag
+			keyGen = ssh.NewECDSAGenerator(elliptic.P521())
+		}
+		host := u.Host
+		if u.Port() == "" {
+			host = host + ":22"
+		}
+		if err := generateSSH(ctx, keyGen, name, host, tmpDir); err != nil {
 			return err
 		}
 		withAuth = true
@@ -193,27 +213,13 @@ func generateBasicAuth(ctx context.Context, name string) error {
 	return nil
 }
 
-func generateSSH(ctx context.Context, name, host, tmpDir string) error {
-	logGenerate("generating host key for %s", host)
-
-	command := fmt.Sprintf("ssh-keyscan %s > %s/known_hosts", host, tmpDir)
-	if _, err := utils.execCommand(ctx, ModeStderrOS, command); err != nil {
-		return fmt.Errorf("ssh-keyscan failed")
-	}
-
+func generateSSH(ctx context.Context, generator ssh.KeyPairGenerator, name, host, user string) error {
 	logGenerate("generating deploy key")
-
-	command = fmt.Sprintf("ssh-keygen -b 2048 -t rsa -f %s/identity -q -N \"\"", tmpDir)
-	if _, err := utils.execCommand(ctx, ModeStderrOS, command); err != nil {
-		return fmt.Errorf("ssh-keygen failed")
-	}
-
-	command = fmt.Sprintf("cat %s/identity.pub", tmpDir)
-	if deployKey, err := utils.execCommand(ctx, ModeCapture, command); err != nil {
-		return fmt.Errorf("unable to read identity.pub: %w", err)
-	} else {
-		fmt.Print(deployKey)
+	kp, err := generator.Generate()
+	if err != nil {
+		return fmt.Errorf("SSH key pair generation failed: %w", err)
 	}
+	fmt.Printf("%s", kp.PublicKey)
 
 	prompt := promptui.Prompt{
 		Label:     "Have you added the deploy key to your repository",
@@ -223,9 +229,16 @@ func generateSSH(ctx context.Context, name, host, tmpDir string) error {
 		return fmt.Errorf("aborting")
 	}
 
+	logAction("collecting SSH server public key for generated public key algorithm")
+	serverKey, err := ssh.ScanHostKey(host, user, kp)
+	if err != nil {
+		return err
+	}
+	logSuccess("collected public key from SSH server")
+
 	logAction("saving keys")
-	files := fmt.Sprintf("--from-file=%s/identity --from-file=%s/identity.pub --from-file=%s/known_hosts",
-		tmpDir, tmpDir, tmpDir)
+	files := fmt.Sprintf("--from-literal=identity=\"%s\" --from-literal=identity.pub=\"%s\" --from-literal=known_hosts=\"%s\"",
+		kp.PublicKey, kp.PrivateKey, serverKey)
 	secret := fmt.Sprintf("kubectl -n %s create secret generic %s %s --dry-run=client -oyaml | kubectl apply -f-",
 		namespace, name, files)
 	if _, err := utils.execCommand(ctx, ModeOS, secret); err != nil {
diff --git a/go.mod b/go.mod
index 8f051ccb..1d0ca0cc 100644
--- a/go.mod
+++ b/go.mod
@@ -8,6 +8,7 @@ require (
 	github.com/fluxcd/source-controller v0.0.1-beta.1
 	github.com/manifoldco/promptui v0.7.0
 	github.com/spf13/cobra v1.0.0
+	golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073
 	k8s.io/api v0.18.2
 	k8s.io/apimachinery v0.18.2
 	k8s.io/client-go v0.18.2
diff --git a/internal/ssh/host_scan.go b/internal/ssh/host_scan.go
new file mode 100644
index 00000000..85d6720c
--- /dev/null
+++ b/internal/ssh/host_scan.go
@@ -0,0 +1,47 @@
+package ssh
+
+import (
+	"encoding/base64"
+	"fmt"
+	"net"
+
+	"golang.org/x/crypto/ssh"
+	"golang.org/x/crypto/ssh/knownhosts"
+)
+
+func ScanHostKey(host string, user string, pair *KeyPair) ([]byte, error) {
+	signer, err := ssh.ParsePrivateKey(pair.PrivateKey)
+	if err != nil {
+		return nil, err
+	}
+	col := &collector{}
+	config := &ssh.ClientConfig{
+		User: user,
+		Auth: []ssh.AuthMethod{
+			ssh.PublicKeys(signer),
+		},
+		HostKeyCallback: col.StoreKey(),
+	}
+	client, err := ssh.Dial("tcp", host, config)
+	if err == nil {
+		defer client.Close()
+	}
+	if len(col.knownKeys) > 0 {
+		return col.knownKeys, nil
+	}
+	return col.knownKeys, err
+}
+
+type collector struct {
+	knownKeys []byte
+}
+
+func (c *collector) StoreKey() ssh.HostKeyCallback {
+	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
+		c.knownKeys = append(
+			c.knownKeys,
+			fmt.Sprintf("%s %s %s\n", knownhosts.Normalize(hostname), key.Type(), base64.StdEncoding.EncodeToString(key.Marshal()))...,
+		)
+		return nil
+	}
+}
diff --git a/internal/ssh/key_pair.go b/internal/ssh/key_pair.go
new file mode 100644
index 00000000..f251c18c
--- /dev/null
+++ b/internal/ssh/key_pair.go
@@ -0,0 +1,89 @@
+package ssh
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/x509"
+	"encoding/pem"
+
+	"golang.org/x/crypto/ssh"
+)
+
+type KeyPair struct {
+	PublicKey  []byte
+	PrivateKey []byte
+}
+
+type KeyPairGenerator interface {
+	Generate() (*KeyPair, error)
+}
+
+type RSAGenerator struct {
+	bits int
+}
+
+func NewRSAGenerator(bits int) KeyPairGenerator {
+	return &RSAGenerator{bits}
+}
+
+func (g *RSAGenerator) Generate() (*KeyPair, error) {
+	pk, err := rsa.GenerateKey(rand.Reader, g.bits)
+	if err != nil {
+		return nil, err
+	}
+	err = pk.Validate()
+	if err != nil {
+		return nil, err
+	}
+	pub, err := generatePublicKey(&pk.PublicKey)
+	if err != nil {
+		return nil, err
+	}
+	return &KeyPair{
+		PublicKey:  pub,
+		PrivateKey: encodePrivateKeyToPEM(pk),
+	}, nil
+}
+
+type ECDSAGenerator struct {
+	c elliptic.Curve
+}
+
+func NewECDSAGenerator(c elliptic.Curve) KeyPairGenerator {
+	return &ECDSAGenerator{c}
+}
+
+func (g *ECDSAGenerator) Generate() (*KeyPair, error) {
+	pk, err := ecdsa.GenerateKey(g.c, rand.Reader)
+	if err != nil {
+		return nil, err
+	}
+	pub, err := generatePublicKey(&pk.PublicKey)
+	if err != nil {
+		return nil, err
+	}
+	return &KeyPair{
+		PublicKey:  pub,
+		PrivateKey: encodePrivateKeyToPEM(pk),
+	}, nil
+}
+
+func generatePublicKey(pk interface{}) ([]byte, error) {
+	b, err := ssh.NewPublicKey(pk)
+	if err != nil {
+		return nil, err
+	}
+	k := ssh.MarshalAuthorizedKey(b)
+	return k, nil
+}
+
+func encodePrivateKeyToPEM(pk interface{}) []byte {
+	b, _ := x509.MarshalPKCS8PrivateKey(pk)
+	block := pem.Block{
+		Type:  "PRIVATE KEY",
+		Bytes: b,
+	}
+	return pem.EncodeToMemory(&block)
+}
-- 
GitLab