diff --git a/pkg/spec/postgresql.go b/pkg/spec/postgresql.go index 7dac7992b3de0e1eff615c6be733a02a0f0ea4d8..c7b5df902edf0fbea2a62bf30c10712838e48250 100644 --- a/pkg/spec/postgresql.go +++ b/pkg/spec/postgresql.go @@ -3,6 +3,7 @@ package spec import ( "encoding/json" "fmt" + "regexp" "strings" "time" @@ -76,6 +77,12 @@ const ( ClusterStatusInvalid PostgresStatus = "Invalid" ) +const ( + serviceNameMaxLength = 63 + clusterNameMaxLength = serviceNameMaxLength - len("-repl") + serviceNameRegexString = `^[a-z]([-a-z0-9]*[a-z0-9])?$` +) + // Postgresql defines PostgreSQL Custom Resource Definition Object. type Postgresql struct { metav1.TypeMeta `json:",inline"` @@ -126,7 +133,10 @@ type PostgresqlList struct { Items []Postgresql `json:"items"` } -var weekdays = map[string]int{"Sun": 0, "Mon": 1, "Tue": 2, "Wed": 3, "Thu": 4, "Fri": 5, "Sat": 6} +var ( + weekdays = map[string]int{"Sun": 0, "Mon": 1, "Tue": 2, "Wed": 3, "Thu": 4, "Fri": 5, "Sat": 6} + serviceNameRegex = regexp.MustCompile(serviceNameRegexString) +) func parseTime(s string) (time.Time, error) { parts := strings.Split(s, ":") @@ -225,10 +235,31 @@ func extractClusterName(clusterName string, teamName string) (string, error) { if strings.ToLower(clusterName[:teamNameLen+1]) != strings.ToLower(teamName)+"-" { return "", fmt.Errorf("name must match {TEAM}-{NAME} format") } + if len(clusterName) > clusterNameMaxLength { + return "", fmt.Errorf("name cannot be longer than %d characters", clusterNameMaxLength) + } + if !serviceNameRegex.MatchString(clusterName) { + return "", fmt.Errorf("name must confirm to DNS-1035, regex used for validation is %q", + serviceNameRegexString) + } return clusterName[teamNameLen+1:], nil } +func validateCloneClusterDescription(clone *CloneDescription) error { + // when cloning from the basebackup (no end timestamp) check that the cluster name is a valid service name + if clone.ClusterName != "" && clone.EndTimestamp == "" { + if !serviceNameRegex.MatchString(clone.ClusterName) { + return fmt.Errorf("clone cluster name must confirm to DNS-1035, regex used for validation is %q", + serviceNameRegexString) + } + if len(clone.ClusterName) > serviceNameMaxLength { + return fmt.Errorf("clone cluster name must be no longer than %d characters", serviceNameMaxLength) + } + } + return nil +} + type postgresqlListCopy PostgresqlList type postgresqlCopy Postgresql @@ -252,22 +283,16 @@ func (p *Postgresql) UnmarshalJSON(data []byte) error { } tmp2 := Postgresql(tmp) - clusterName, err := extractClusterName(tmp2.ObjectMeta.Name, tmp2.Spec.TeamID) - if err == nil { - tmp2.Spec.ClusterName = clusterName - } else { + if clusterName, err := extractClusterName(tmp2.ObjectMeta.Name, tmp2.Spec.TeamID); err != nil { tmp2.Error = err tmp2.Status = ClusterStatusInvalid + } else if err := validateCloneClusterDescription(&tmp2.Spec.Clone); err != nil { + tmp2.Error = err + tmp2.Status = ClusterStatusInvalid + } else { + tmp2.Spec.ClusterName = clusterName } - // The assumption below is that a cluster to clone, if any, belongs to the same team - if tmp2.Spec.Clone.ClusterName != "" { - _, err := extractClusterName(tmp2.Spec.Clone.ClusterName, tmp2.Spec.TeamID) - if err != nil { - tmp2.Error = fmt.Errorf("%s for the cluster to clone", err) - tmp2.Spec.Clone = CloneDescription{} - tmp2.Status = ClusterStatusInvalid - } - } + *p = tmp2 return nil diff --git a/pkg/spec/postgresql_test.go b/pkg/spec/postgresql_test.go index 091334e8e4942e037b0ff34d8a93092e5b9af84c..07251edeb6cb31df57dc640b42b80cbab62f21c7 100644 --- a/pkg/spec/postgresql_test.go +++ b/pkg/spec/postgresql_test.go @@ -43,7 +43,10 @@ var clusterNames = []struct { {"acid-test", "acid", "test", nil}, {"test-my-name", "test", "my-name", nil}, {"my-team-another-test", "my-team", "another-test", nil}, - {"------strange-team-cluster", "-----", "strange-team-cluster", nil}, + {"------strange-team-cluster", "-----", "strange-team-cluster", + errors.New(`name must confirm to DNS-1035, regex used for validation is "^[a-z]([-a-z0-9]*[a-z0-9])?$"`)}, + {"fooobar-fooobarfooobarfooobarfooobarfooobarfooobarfooobarfooobar", "fooobar", "", + errors.New("name cannot be longer than 58 characters")}, {"acid-test", "test", "", errors.New("name must match {TEAM}-{NAME} format")}, {"-test", "", "", errors.New("team name is empty")}, {"-test", "-", "", errors.New("name must match {TEAM}-{NAME} format")}, @@ -51,6 +54,18 @@ var clusterNames = []struct { {"-", "-", "", errors.New("name is too short")}, } +var cloneClusterDescriptions = []struct { + in *CloneDescription + err error +}{ + {&CloneDescription{"foo+bar", "", "NotEmpty"}, nil}, + {&CloneDescription{"foo+bar", "", ""}, + errors.New(`clone cluster name must confirm to DNS-1035, regex used for validation is "^[a-z]([-a-z0-9]*[a-z0-9])?$"`)}, + {&CloneDescription{"foobar123456789012345678901234567890123456789012345678901234567890", "", ""}, + errors.New("clone cluster name must be no longer than 63 characters")}, + {&CloneDescription{"foobar", "", ""}, nil}, +} + var maintenanceWindows = []struct { in []byte out MaintenanceWindow @@ -279,14 +294,15 @@ var unmarshalCluster = []struct { Name: "acid-testcluster1", }, Spec: PostgresSpec{ - TeamID: "acid", - Clone: CloneDescription{}, + TeamID: "acid", + Clone: CloneDescription{ + ClusterName: "team-batman", + }, ClusterName: "testcluster1", }, - Status: ClusterStatusInvalid, - Error: errors.New("name must match {TEAM}-{NAME} format for the cluster to clone"), + Error: nil, }, - marshal: []byte(`{"kind":"Postgresql","apiVersion":"acid.zalan.do/v1","metadata":{"name":"acid-testcluster1","creationTimestamp":null},"spec":{"postgresql":{"version":"","parameters":null},"volume":{"size":"","storageClass":""},"patroni":{"initdb":null,"pg_hba":null,"ttl":0,"loop_wait":0,"retry_timeout":0,"maximum_lag_on_failover":0},"resources":{"requests":{"cpu":"","memory":""},"limits":{"cpu":"","memory":""}},"teamId":"acid","allowedSourceRanges":null,"numberOfInstances":0,"users":null,"clone":{}},"status":"Invalid"}`), err: nil}, + marshal: []byte(`{"kind":"Postgresql","apiVersion":"acid.zalan.do/v1","metadata":{"name":"acid-testcluster1","creationTimestamp":null},"spec":{"postgresql":{"version":"","parameters":null},"volume":{"size":"","storageClass":""},"patroni":{"initdb":null,"pg_hba":null,"ttl":0,"loop_wait":0,"retry_timeout":0,"maximum_lag_on_failover":0},"resources":{"requests":{"cpu":"","memory":""},"limits":{"cpu":"","memory":""}},"teamId":"acid","allowedSourceRanges":null,"numberOfInstances":0,"users":null,"clone":{"cluster":"team-batman"}}}`), err: nil}, {[]byte(`{"kind": "Postgresql","apiVersion": "acid.zalan.do/v1"`), Postgresql{}, []byte{}, @@ -350,11 +366,12 @@ func TestParseTime(t *testing.T) { for _, tt := range parseTimeTests { aTime, err := parseTime(tt.in) if err != nil { - if err.Error() != tt.err.Error() { + if tt.err == nil || err.Error() != tt.err.Error() { t.Errorf("ParseTime expected error: %v, got: %v", tt.err, err) } - continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } if aTime != tt.out { @@ -367,11 +384,12 @@ func TestWeekdayTime(t *testing.T) { for _, tt := range parseWeekdayTests { aTime, err := parseWeekday(tt.in) if err != nil { - if err.Error() != tt.err.Error() { + if tt.err == nil || err.Error() != tt.err.Error() { t.Errorf("ParseWeekday expected error: %v, got: %v", tt.err, err) } - continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } if aTime != tt.out { @@ -383,9 +401,13 @@ func TestWeekdayTime(t *testing.T) { func TestClusterName(t *testing.T) { for _, tt := range clusterNames { name, err := extractClusterName(tt.in, tt.inTeam) - if err != nil && err.Error() != tt.err.Error() { - t.Errorf("extractClusterName expected error: %v, got: %v", tt.err, err) + if err != nil { + if tt.err == nil || err.Error() != tt.err.Error() { + t.Errorf("extractClusterName expected error: %v, got: %v", tt.err, err) + } continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } if name != tt.clusterName { t.Errorf("Expected cluserName: %q, got: %q", tt.clusterName, name) @@ -393,17 +415,29 @@ func TestClusterName(t *testing.T) { } } +func TestCloneClusterDescription(t *testing.T) { + for _, tt := range cloneClusterDescriptions { + if err := validateCloneClusterDescription(tt.in); err != nil { + if tt.err == nil || err.Error() != tt.err.Error() { + t.Errorf("testCloneClusterDescription expected error: %v, got: %v", tt.err, err) + } + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) + } + } +} + func TestUnmarshalMaintenanceWindow(t *testing.T) { for _, tt := range maintenanceWindows { var m MaintenanceWindow err := m.UnmarshalJSON(tt.in) - if err != nil && err.Error() != tt.err.Error() { - t.Errorf("MaintenanceWindow unmarshal expected error: %v, got %v", tt.err, err) - continue - } - if tt.err != nil && err == nil { - t.Errorf("Expected error") + if err != nil { + if tt.err == nil || err.Error() != tt.err.Error() { + t.Errorf("MaintenanceWindow unmarshal expected error: %v, got %v", tt.err, err) + } continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } if !reflect.DeepEqual(m, tt.out) { @@ -421,7 +455,6 @@ func TestMarshalMaintenanceWindow(t *testing.T) { s, err := tt.out.MarshalJSON() if err != nil { t.Errorf("Marshal Error: %v", err) - continue } if !bytes.Equal(s, tt.in) { @@ -435,11 +468,12 @@ func TestPostgresUnmarshal(t *testing.T) { var cluster Postgresql err := cluster.UnmarshalJSON(tt.in) if err != nil { - if err.Error() != tt.err.Error() { + if tt.err == nil || err.Error() != tt.err.Error() { t.Errorf("Unmarshal expected error: %v, got: %v", tt.err, err) } - continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } if !reflect.DeepEqual(cluster, tt.out) { @@ -457,7 +491,6 @@ func TestMarshal(t *testing.T) { m, err := json.Marshal(tt.out) if err != nil { t.Errorf("Marshal error: %v", err) - continue } if !bytes.Equal(m, tt.marshal) { t.Errorf("Marshal Postgresql expected: %q, got: %q", string(tt.marshal), string(m)) @@ -481,10 +514,15 @@ func TestUnmarshalPostgresList(t *testing.T) { for _, tt := range postgresqlList { var list PostgresqlList err := list.UnmarshalJSON(tt.in) - if err != nil && err.Error() != tt.err.Error() { - t.Errorf("PostgresqlList unmarshal expected error: %v, got: %v", tt.err, err) - return + if err != nil { + if tt.err == nil || err.Error() != tt.err.Error() { + t.Errorf("PostgresqlList unmarshal expected error: %v, got: %v", tt.err, err) + } + continue + } else if tt.err != nil { + t.Errorf("Expected error: %v", tt.err) } + if !reflect.DeepEqual(list, tt.out) { t.Errorf("Postgresql list unmarshall expected: %#v, got: %#v", tt.out, list) }