From eef49500a532ac8afee36a7ae5e0e970f408b1a4 Mon Sep 17 00:00:00 2001
From: cosimomeli <cosimomeli@users.noreply.github.com>
Date: Thu, 19 Dec 2024 12:32:09 +0100
Subject: [PATCH] Add support for EBS CSI Driver (#2677)

* Add support for EBS CSI Driver
---
 .gitignore                                    |   2 +
 .../templates/clusterrole.yaml                |   2 +-
 pkg/cluster/volumes.go                        |   2 +-
 pkg/cluster/volumes_test.go                   |  24 ++++
 pkg/util/constants/aws.go                     |   1 +
 pkg/util/volumes/ebs.go                       |  10 +-
 pkg/util/volumes/ebs_test.go                  | 123 ++++++++++++++++++
 7 files changed, 160 insertions(+), 4 deletions(-)
 create mode 100644 pkg/util/volumes/ebs_test.go

diff --git a/.gitignore b/.gitignore
index 66a8103d..5938db21 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,3 +104,5 @@ e2e/tls
 mocks
 
 ui/.npm/
+
+.DS_Store
diff --git a/charts/postgres-operator/templates/clusterrole.yaml b/charts/postgres-operator/templates/clusterrole.yaml
index 1fd066fa..ad3b4606 100644
--- a/charts/postgres-operator/templates/clusterrole.yaml
+++ b/charts/postgres-operator/templates/clusterrole.yaml
@@ -141,7 +141,7 @@ rules:
   - get
   - list
   - patch
-{{- if toString .Values.configKubernetes.storage_resize_mode | eq "pvc" }}
+{{- if or (toString .Values.configKubernetes.storage_resize_mode | eq "pvc") (toString .Values.configKubernetes.storage_resize_mode | eq "mixed") }}
   - update
 {{- end }}
  # to read existing PVs. Creation should be done via dynamic provisioning
diff --git a/pkg/cluster/volumes.go b/pkg/cluster/volumes.go
index 165c6c7a..240220cc 100644
--- a/pkg/cluster/volumes.go
+++ b/pkg/cluster/volumes.go
@@ -151,7 +151,7 @@ func (c *Cluster) populateVolumeMetaData() error {
 	volumeIds := []string{}
 	var volumeID string
 	for _, pv := range pvs {
-		volumeID, err = c.VolumeResizer.ExtractVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID)
+		volumeID, err = c.VolumeResizer.GetProviderVolumeID(pv)
 		if err != nil {
 			continue
 		}
diff --git a/pkg/cluster/volumes_test.go b/pkg/cluster/volumes_test.go
index 99780982..95ecc762 100644
--- a/pkg/cluster/volumes_test.go
+++ b/pkg/cluster/volumes_test.go
@@ -216,6 +216,12 @@ func TestMigrateEBS(t *testing.T) {
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-1")).Return("ebs-volume-1", nil)
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-2")).Return("ebs-volume-2", nil)
 
+	resizer.EXPECT().GetProviderVolumeID(gomock.Any()).
+		DoAndReturn(func(pv *v1.PersistentVolume) (string, error) {
+			return resizer.ExtractVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID)
+		}).
+		Times(2)
+
 	resizer.EXPECT().DescribeVolumes(gomock.Eq([]string{"ebs-volume-1", "ebs-volume-2"})).Return(
 		[]volumes.VolumeProperties{
 			{VolumeID: "ebs-volume-1", VolumeType: "gp2", Size: 100},
@@ -322,6 +328,12 @@ func TestMigrateGp3Support(t *testing.T) {
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-2")).Return("ebs-volume-2", nil)
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-3")).Return("ebs-volume-3", nil)
 
+	resizer.EXPECT().GetProviderVolumeID(gomock.Any()).
+		DoAndReturn(func(pv *v1.PersistentVolume) (string, error) {
+			return resizer.ExtractVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID)
+		}).
+		Times(3)
+
 	resizer.EXPECT().DescribeVolumes(gomock.Eq([]string{"ebs-volume-1", "ebs-volume-2", "ebs-volume-3"})).Return(
 		[]volumes.VolumeProperties{
 			{VolumeID: "ebs-volume-1", VolumeType: "gp3", Size: 100, Iops: 3000},
@@ -377,6 +389,12 @@ func TestManualGp2Gp3Support(t *testing.T) {
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-1")).Return("ebs-volume-1", nil)
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-2")).Return("ebs-volume-2", nil)
 
+	resizer.EXPECT().GetProviderVolumeID(gomock.Any()).
+		DoAndReturn(func(pv *v1.PersistentVolume) (string, error) {
+			return resizer.ExtractVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID)
+		}).
+		Times(2)
+
 	resizer.EXPECT().DescribeVolumes(gomock.Eq([]string{"ebs-volume-1", "ebs-volume-2"})).Return(
 		[]volumes.VolumeProperties{
 			{VolumeID: "ebs-volume-1", VolumeType: "gp2", Size: 150, Iops: 3000},
@@ -436,6 +454,12 @@ func TestDontTouchType(t *testing.T) {
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-1")).Return("ebs-volume-1", nil)
 	resizer.EXPECT().ExtractVolumeID(gomock.Eq("aws://eu-central-1b/ebs-volume-2")).Return("ebs-volume-2", nil)
 
+	resizer.EXPECT().GetProviderVolumeID(gomock.Any()).
+		DoAndReturn(func(pv *v1.PersistentVolume) (string, error) {
+			return resizer.ExtractVolumeID(pv.Spec.AWSElasticBlockStore.VolumeID)
+		}).
+		Times(2)
+
 	resizer.EXPECT().DescribeVolumes(gomock.Eq([]string{"ebs-volume-1", "ebs-volume-2"})).Return(
 		[]volumes.VolumeProperties{
 			{VolumeID: "ebs-volume-1", VolumeType: "gp2", Size: 150, Iops: 3000},
diff --git a/pkg/util/constants/aws.go b/pkg/util/constants/aws.go
index f1cfd597..147e5888 100644
--- a/pkg/util/constants/aws.go
+++ b/pkg/util/constants/aws.go
@@ -7,6 +7,7 @@ const (
 	// EBS related constants
 	EBSVolumeIDStart = "/vol-"
 	EBSProvisioner   = "kubernetes.io/aws-ebs"
+	EBSDriver        = "ebs.csi.aws.com"
 	//https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_VolumeModification.html
 	EBSVolumeStateModifying     = "modifying"
 	EBSVolumeStateOptimizing    = "optimizing"
diff --git a/pkg/util/volumes/ebs.go b/pkg/util/volumes/ebs.go
index f625dab2..cb8f8e97 100644
--- a/pkg/util/volumes/ebs.go
+++ b/pkg/util/volumes/ebs.go
@@ -36,7 +36,8 @@ func (r *EBSVolumeResizer) IsConnectedToProvider() bool {
 
 // VolumeBelongsToProvider checks if the given persistent volume is backed by EBS.
 func (r *EBSVolumeResizer) VolumeBelongsToProvider(pv *v1.PersistentVolume) bool {
-	return pv.Spec.AWSElasticBlockStore != nil && pv.Annotations[constants.VolumeStorateProvisionerAnnotation] == constants.EBSProvisioner
+	return (pv.Spec.AWSElasticBlockStore != nil && pv.Annotations[constants.VolumeStorateProvisionerAnnotation] == constants.EBSProvisioner) ||
+		(pv.Spec.CSI != nil && pv.Spec.CSI.Driver == constants.EBSDriver)
 }
 
 // ExtractVolumeID extracts volumeID from "aws://eu-central-1a/vol-075ddfc4a127d0bd4"
@@ -54,7 +55,12 @@ func (r *EBSVolumeResizer) ExtractVolumeID(volumeID string) (string, error) {
 
 // GetProviderVolumeID converts aws://eu-central-1b/vol-00f93d4827217c629 to vol-00f93d4827217c629 for EBS volumes
 func (r *EBSVolumeResizer) GetProviderVolumeID(pv *v1.PersistentVolume) (string, error) {
-	volumeID := pv.Spec.AWSElasticBlockStore.VolumeID
+	var volumeID string = ""
+	if pv.Spec.CSI != nil {
+		volumeID = pv.Spec.CSI.VolumeHandle
+	} else if pv.Spec.AWSElasticBlockStore != nil {
+		volumeID = pv.Spec.AWSElasticBlockStore.VolumeID
+	}
 	if volumeID == "" {
 		return "", fmt.Errorf("got empty volume id for volume %v", pv)
 	}
diff --git a/pkg/util/volumes/ebs_test.go b/pkg/util/volumes/ebs_test.go
new file mode 100644
index 00000000..6f722ff7
--- /dev/null
+++ b/pkg/util/volumes/ebs_test.go
@@ -0,0 +1,123 @@
+package volumes
+
+import (
+	"fmt"
+	"testing"
+	v1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+)
+
+func TestGetProviderVolumeID(t *testing.T) {
+	tests := []struct {
+		name     string
+		pv       *v1.PersistentVolume
+		expected string
+		err      error
+	}{
+		{
+			name: "CSI volume handle",
+			pv: &v1.PersistentVolume{
+				Spec: v1.PersistentVolumeSpec{
+					PersistentVolumeSource: v1.PersistentVolumeSource{
+						CSI: &v1.CSIPersistentVolumeSource{
+							VolumeHandle: "vol-075ddfc4a127d0bd5",
+						},
+					},
+				},
+			},
+			expected: "vol-075ddfc4a127d0bd5",
+			err:      nil,
+		},
+		{
+			name: "AWS EBS volume handle",
+			pv: &v1.PersistentVolume{
+				Spec: v1.PersistentVolumeSpec{
+					PersistentVolumeSource: v1.PersistentVolumeSource{
+						AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{
+							VolumeID: "aws://eu-central-1a/vol-075ddfc4a127d0bd4",
+						},
+					},
+				},
+			},
+			expected: "vol-075ddfc4a127d0bd4",
+			err:      nil,
+		},
+		{
+			name: "Empty volume handle",
+			pv: &v1.PersistentVolume{
+				Spec: v1.PersistentVolumeSpec{},
+			},
+			expected: "",
+			err:      fmt.Errorf("got empty volume id for volume %v", &v1.PersistentVolume{}),
+		},
+	}
+
+	resizer := EBSVolumeResizer{}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			volumeID, err := resizer.GetProviderVolumeID(tt.pv)
+			if volumeID != tt.expected || (err != nil && err.Error() != tt.err.Error()) {
+				t.Errorf("expected %v, got %v, expected err %v, got %v", tt.expected, volumeID, tt.err, err)
+			}
+		})
+	}
+}
+
+func TestVolumeBelongsToProvider(t *testing.T) {
+	tests := []struct {
+		name     string
+		pv       *v1.PersistentVolume
+		expected bool
+	}{
+		{
+			name: "CSI volume handle",
+			pv: &v1.PersistentVolume{
+				Spec: v1.PersistentVolumeSpec{
+					PersistentVolumeSource: v1.PersistentVolumeSource{
+						CSI: &v1.CSIPersistentVolumeSource{
+							Driver:       "ebs.csi.aws.com",
+							VolumeHandle: "vol-075ddfc4a127d0bd5",
+						},
+					},
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "AWS EBS volume handle",
+			pv: &v1.PersistentVolume{
+				ObjectMeta: metav1.ObjectMeta{
+					Annotations: map[string]string {
+						"pv.kubernetes.io/provisioned-by": "kubernetes.io/aws-ebs",
+					},
+				},
+				Spec: v1.PersistentVolumeSpec{
+					PersistentVolumeSource: v1.PersistentVolumeSource{
+						AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{
+							VolumeID: "aws://eu-central-1a/vol-075ddfc4a127d0bd4",
+						},
+					},
+				},
+			},
+			expected: true,
+		},
+		{
+			name: "Empty volume source",
+			pv: &v1.PersistentVolume{
+				Spec: v1.PersistentVolumeSpec{},
+			},
+			expected: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			resizer := EBSVolumeResizer{}
+			isProvider := resizer.VolumeBelongsToProvider(tt.pv)
+			if isProvider != tt.expected {
+				t.Errorf("expected %v, got %v", tt.expected, isProvider)
+			}
+		})
+	}
+}
-- 
GitLab