diff --git a/cluster-autoscaler/cloudprovider/azure/azure_cache.go b/cluster-autoscaler/cloudprovider/azure/azure_cache.go index 334048164dfc73abd072654f29bf0fdf6df625ff..cc3ded38e33f63d816f8443a0b8ada99234a9218 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_cache.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_cache.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/go-autorest/autorest/to" "github.com/Azure/skewer" @@ -67,13 +68,18 @@ type azureCache struct { // Cache content. - // resourceGroup specifies the name of the resource group that this cache tracks - resourceGroup string + // resourceGroup specifies the name of the node resource group that this cache tracks + resourceGroup string + clusterResourceGroup string + clusterName string + + // enableVMsAgentPool specifies whether VMs agent pool type is supported. + enableVMsAgentPool bool // vmType can be one of vmTypeVMSS (default), vmTypeStandard vmType string - vmsPoolSet map[string]struct{} // track the nodepools that're vms pool + vmsPoolMap map[string]armcontainerservice.AgentPool // track the nodepools that're vms pool // scaleSets keeps the set of all known scalesets in the resource group, populated/refreshed via VMSS.List() call. // It is only used/populated if vmType is vmTypeVMSS (default). @@ -106,8 +112,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az azClient: client, refreshInterval: cacheTTL, resourceGroup: config.ResourceGroup, + clusterResourceGroup: config.ClusterResourceGroup, + clusterName: config.ClusterName, + enableVMsAgentPool: config.EnableVMsAgentPool, vmType: config.VMType, - vmsPoolSet: make(map[string]struct{}), + vmsPoolMap: make(map[string]armcontainerservice.AgentPool), scaleSets: make(map[string]compute.VirtualMachineScaleSet), virtualMachines: make(map[string][]compute.VirtualMachine), registeredNodeGroups: make([]cloudprovider.NodeGroup, 0), @@ -130,11 +139,11 @@ func newAzureCache(client *azClient, cacheTTL time.Duration, config Config) (*az return cache, nil } -func (m *azureCache) getVMsPoolSet() map[string]struct{} { +func (m *azureCache) getVMsPoolMap() map[string]armcontainerservice.AgentPool { m.mutex.Lock() defer m.mutex.Unlock() - return m.vmsPoolSet + return m.vmsPoolMap } func (m *azureCache) getVirtualMachines() map[string][]compute.VirtualMachine { @@ -232,13 +241,20 @@ func (m *azureCache) fetchAzureResources() error { return err } m.scaleSets = vmssResult - vmResult, vmsPoolSet, err := m.fetchVirtualMachines() + vmResult, err := m.fetchVirtualMachines() if err != nil { return err } // we fetch both sets of resources since CAS may operate on mixed nodepools m.virtualMachines = vmResult - m.vmsPoolSet = vmsPoolSet + // fetch VMs pools if enabled + if m.enableVMsAgentPool { + vmsPoolMap, err := m.fetchVMsPools() + if err != nil { + return err + } + m.vmsPoolMap = vmsPoolMap + } return nil } @@ -251,19 +267,17 @@ const ( ) // fetchVirtualMachines returns the updated list of virtual machines in the config resource group using the Azure API. 
-func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, map[string]struct{}, error) { +func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine, error) { ctx, cancel := getContextWithCancel() defer cancel() result, err := m.azClient.virtualMachinesClient.List(ctx, m.resourceGroup) if err != nil { klog.Errorf("VirtualMachinesClient.List in resource group %q failed: %v", m.resourceGroup, err) - return nil, nil, err.Error() + return nil, err.Error() } instances := make(map[string][]compute.VirtualMachine) - // track the nodepools that're vms pools - vmsPoolSet := make(map[string]struct{}) for _, instance := range result { if instance.Tags == nil { continue @@ -280,20 +294,43 @@ func (m *azureCache) fetchVirtualMachines() (map[string][]compute.VirtualMachine } instances[to.String(vmPoolName)] = append(instances[to.String(vmPoolName)], instance) + } + return instances, nil +} - // if the nodepool is already in the map, skip it - if _, ok := vmsPoolSet[to.String(vmPoolName)]; ok { - continue +// fetchVMsPools returns a name to agentpool map of all the VMs pools in the cluster +func (m *azureCache) fetchVMsPools() (map[string]armcontainerservice.AgentPool, error) { + ctx, cancel := getContextWithTimeout(vmsContextTimeout) + defer cancel() + + // defensive check, should never happen when enableVMsAgentPool toggle is on + if m.azClient.agentPoolClient == nil { + return nil, errors.New("agentPoolClient is nil") + } + + vmsPoolMap := make(map[string]armcontainerservice.AgentPool) + pager := m.azClient.agentPoolClient.NewListPager(m.clusterResourceGroup, m.clusterName, nil) + var aps []*armcontainerservice.AgentPool + for pager.More() { + resp, err := pager.NextPage(ctx) + if err != nil { + klog.Errorf("agentPoolClient.pager.NextPage in cluster %s resource group %s failed: %v", + m.clusterName, m.clusterResourceGroup, err) + return nil, err } + aps = append(aps, resp.Value...) + } - // nodes from vms pool will have tag "aks-managed-agentpool-type" set to "VirtualMachines" - if agentpoolType := tags[agentpoolTypeTag]; agentpoolType != nil { - if strings.EqualFold(to.String(agentpoolType), vmsPoolType) { - vmsPoolSet[to.String(vmPoolName)] = struct{}{} - } + for _, ap := range aps { + if ap != nil && ap.Name != nil && ap.Properties != nil && ap.Properties.Type != nil && + *ap.Properties.Type == armcontainerservice.AgentPoolTypeVirtualMachines { + // we only care about VMs pools, skip other types + klog.V(6).Infof("Found VMs pool %q", *ap.Name) + vmsPoolMap[*ap.Name] = *ap } } - return instances, vmsPoolSet, nil + + return vmsPoolMap, nil } // fetchScaleSets returns the updated list of scale sets in the config resource group using the Azure API. @@ -422,7 +459,7 @@ func (m *azureCache) HasInstance(providerID string) (bool, error) { // FindForInstance returns node group of the given Instance func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudprovider.NodeGroup, error) { - vmsPoolSet := m.getVMsPoolSet() + vmsPoolMap := m.getVMsPoolMap() m.mutex.Lock() defer m.mutex.Unlock() @@ -441,7 +478,7 @@ func (m *azureCache) FindForInstance(instance *azureRef, vmType string) (cloudpr } // cluster with vmss pool only - if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolSet) == 0 { + if vmType == providerazureconsts.VMTypeVMSS && len(vmsPoolMap) == 0 { if m.areAllScaleSetsUniform() { // Omit virtual machines not managed by vmss only in case of uniform scale set. 
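(Aside, illustrative only; not part of the patch.) fetchVMsPools above drives agentPoolClient.NewListPager, and the tests further down stub that call through a getFakeAgentpoolListPager helper that is defined outside this hunk. For readers unfamiliar with the track2 SDK paging model, a single-page fake pager can be assembled with azcore's runtime.NewPager; the sketch below is a hypothetical stand-in (name and shape assumed), not the PR's actual helper.

package azure

import (
	"context"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5"
)

// fakeSinglePageAgentPoolPager (hypothetical helper) serves the supplied agent
// pools as one page and then reports that no further pages exist.
func fakeSinglePageAgentPoolPager(pools ...*armcontainerservice.AgentPool) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
	return runtime.NewPager(runtime.PagingHandler[armcontainerservice.AgentPoolsClientListResponse]{
		// More is consulted after each page; with NextLink left nil the pager stops after the first fetch.
		More: func(resp armcontainerservice.AgentPoolsClientListResponse) bool {
			return resp.NextLink != nil
		},
		// Fetcher builds the single fake page containing every supplied pool.
		Fetcher: func(ctx context.Context, _ *armcontainerservice.AgentPoolsClientListResponse) (armcontainerservice.AgentPoolsClientListResponse, error) {
			return armcontainerservice.AgentPoolsClientListResponse{
				AgentPoolListResult: armcontainerservice.AgentPoolListResult{Value: pools},
			}, nil
		},
	})
}

A pager built this way would be handed to the mock via mockAgentpoolclient.EXPECT().NewListPager(...).Return(...), as the test hunks below already show.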
if ok := virtualMachineRE.Match([]byte(inst.Name)); ok { diff --git a/cluster-autoscaler/cloudprovider/azure/azure_cache_test.go b/cluster-autoscaler/cloudprovider/azure/azure_cache_test.go index 3fd801afbe8f729c8ad73bbcefdcd33f6242ff8b..843b54e1cf2891c4022c2b2aec839a365c60836a 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_cache_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_cache_test.go @@ -22,9 +22,42 @@ import ( "k8s.io/autoscaler/cluster-autoscaler/cloudprovider" providerazureconsts "sigs.k8s.io/cloud-provider-azure/pkg/consts" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" + "github.com/Azure/go-autorest/autorest/to" "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" ) +func TestFetchVMsPools(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + provider := newTestProvider(t) + ac := provider.azureManager.azureCache + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ac.azClient.agentPoolClient = mockAgentpoolclient + + vmsPool := getTestVMsAgentPool(false) + vmssPoolType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets + vmssPool := armcontainerservice.AgentPool{ + Name: to.StringPtr("vmsspool1"), + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + Type: &vmssPoolType, + }, + } + invalidPool := armcontainerservice.AgentPool{} + fakeAPListPager := getFakeAgentpoolListPager(&vmsPool, &vmssPool, &invalidPool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). + Return(fakeAPListPager) + + vmsPoolMap, err := ac.fetchVMsPools() + assert.NoError(t, err) + assert.Equal(t, 1, len(vmsPoolMap)) + + _, ok := vmsPoolMap[to.String(vmsPool.Name)] + assert.True(t, ok) +} + func TestRegister(t *testing.T) { provider := newTestProvider(t) ss := newTestScaleSet(provider.azureManager, "ss") diff --git a/cluster-autoscaler/cloudprovider/azure/azure_client.go b/cluster-autoscaler/cloudprovider/azure/azure_client.go index fbc39a62ed28615f3d3345e2305ea31fb5086d49..8281b59c8091142ce8f710eaad65fc5c33214aaf 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_client.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_client.go @@ -19,6 +19,8 @@ package azure import ( "context" "fmt" + "os" + "time" _ "go.uber.org/mock/mockgen/model" // for go:generate @@ -29,7 +31,7 @@ import ( azurecore_policy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2019-07-01/compute" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/azure" @@ -47,7 +49,12 @@ import ( providerazureconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" ) -//go:generate sh -c "mockgen k8s.io/autoscaler/cluster-autoscaler/cloudprovider/azure AgentPoolsClient >./agentpool_client.go" +//go:generate sh -c "mockgen -source=azure_client.go -destination azure_mock_agentpool_client.go -package azure -exclude_interfaces DeploymentsClient" + +const ( + vmsContextTimeout = 5 * time.Minute + vmsAsyncContextTimeout = 30 * time.Minute +) // AgentPoolsClient interface defines the methods needed for scaling vms pool. 
// it is implemented by track2 sdk armcontainerservice.AgentPoolsClient @@ -68,52 +75,89 @@ type AgentPoolsClient interface { machines armcontainerservice.AgentPoolDeleteMachinesParameter, options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) ( *runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) + NewListPager( + resourceGroupName, resourceName string, + options *armcontainerservice.AgentPoolsClientListOptions, + ) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] } func getAgentpoolClientCredentials(cfg *Config) (azcore.TokenCredential, error) { - var cred azcore.TokenCredential - var err error - if cfg.AuthMethod == authMethodCLI { - cred, err = azidentity.NewAzureCLICredential(&azidentity.AzureCLICredentialOptions{ - TenantID: cfg.TenantID}) - if err != nil { - klog.Errorf("NewAzureCLICredential failed: %v", err) - return nil, err + if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal { + // Use MSI + if cfg.UseManagedIdentityExtension { + // Use System Assigned MSI + if cfg.UserAssignedIdentityID == "" { + klog.V(4).Info("Agentpool client: using System Assigned MSI to retrieve access token") + return azidentity.NewManagedIdentityCredential(nil) + } + // Use User Assigned MSI + klog.V(4).Info("Agentpool client: using User Assigned MSI to retrieve access token") + return azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ID: azidentity.ClientID(cfg.UserAssignedIdentityID), + }) } - } else if cfg.AuthMethod == "" || cfg.AuthMethod == authMethodPrincipal { - cred, err = azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil) - if err != nil { - klog.Errorf("NewClientSecretCredential failed: %v", err) - return nil, err + + // Use Service Principal with ClientID and ClientSecret + if cfg.AADClientID != "" && cfg.AADClientSecret != "" { + klog.V(2).Infoln("Agentpool client: using client_id+client_secret to retrieve access token") + return azidentity.NewClientSecretCredential(cfg.TenantID, cfg.AADClientID, cfg.AADClientSecret, nil) } - } else { - return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod) - } - return cred, nil -} -func getAgentpoolClientRetryOptions(cfg *Config) azurecore_policy.RetryOptions { - if cfg.AuthMethod == authMethodCLI { - return azurecore_policy.RetryOptions{ - MaxRetries: -1, // no retry when using CLI auth for UT + // Use Service Principal with ClientCert and AADClientCertPassword + if cfg.AADClientID != "" && cfg.AADClientCertPath != "" { + klog.V(2).Infoln("Agentpool client: using client_cert+client_private_key to retrieve access token") + certData, err := os.ReadFile(cfg.AADClientCertPath) + if err != nil { + return nil, fmt.Errorf("reading the client certificate from file %s failed with error: %w", cfg.AADClientCertPath, err) + } + certs, privateKey, err := azidentity.ParseCertificates(certData, []byte(cfg.AADClientCertPassword)) + if err != nil { + return nil, fmt.Errorf("parsing service principal certificate data failed with error: %w", err) + } + return azidentity.NewClientCertificateCredential(cfg.TenantID, cfg.AADClientID, certs, privateKey, &azidentity.ClientCertificateCredentialOptions{ + SendCertificateChain: true, + }) } } - return azextensions.DefaultRetryOpts() + + if cfg.UseFederatedWorkloadIdentityExtension { + klog.V(4).Info("Agentpool client: using workload identity for access token") + return 
azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ + TokenFilePath: cfg.AADFederatedTokenFile, + }) + } + + return nil, fmt.Errorf("unsupported authorization method: %s", cfg.AuthMethod) } func newAgentpoolClient(cfg *Config) (AgentPoolsClient, error) { - retryOptions := getAgentpoolClientRetryOptions(cfg) + retryOptions := azextensions.DefaultRetryOpts() + cred, err := getAgentpoolClientCredentials(cfg) + if err != nil { + klog.Errorf("failed to get agent pool client credentials: %v", err) + return nil, err + } + + env := azure.PublicCloud // default to public cloud + if cfg.Cloud != "" { + var err error + env, err = azure.EnvironmentFromName(cfg.Cloud) + if err != nil { + klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err) + return nil, err + } + } if cfg.ARMBaseURLForAPClient != "" { klog.V(10).Infof("Using ARMBaseURLForAPClient to create agent pool client") - return newAgentpoolClientWithConfig(cfg.SubscriptionID, nil, cfg.ARMBaseURLForAPClient, "UNKNOWN", retryOptions) + return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, cfg.ARMBaseURLForAPClient, env.TokenAudience, retryOptions, true /*insecureAllowCredentialWithHTTP*/) } - return newAgentpoolClientWithPublicEndpoint(cfg, retryOptions) + return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions, false /*insecureAllowCredentialWithHTTP*/) } func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCredential, - cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) { + cloudCfgEndpoint, cloudCfgAudience string, retryOptions azurecore_policy.RetryOptions, insecureAllowCredentialWithHTTP bool) (AgentPoolsClient, error) { agentPoolsClient, err := armcontainerservice.NewAgentPoolsClient(subscriptionID, cred, &policy.ClientOptions{ ClientOptions: azurecore_policy.ClientOptions{ @@ -125,9 +169,10 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden }, }, }, - Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()), - Transport: azextensions.DefaultHTTPClient(), - Retry: retryOptions, + InsecureAllowCredentialWithHTTP: insecureAllowCredentialWithHTTP, + Telemetry: azextensions.DefaultTelemetryOpts(getUserAgentExtension()), + Transport: azextensions.DefaultHTTPClient(), + Retry: retryOptions, }, }) @@ -139,26 +184,6 @@ func newAgentpoolClientWithConfig(subscriptionID string, cred azcore.TokenCreden return agentPoolsClient, nil } -func newAgentpoolClientWithPublicEndpoint(cfg *Config, retryOptions azurecore_policy.RetryOptions) (AgentPoolsClient, error) { - cred, err := getAgentpoolClientCredentials(cfg) - if err != nil { - klog.Errorf("failed to get agent pool client credentials: %v", err) - return nil, err - } - - // default to public cloud - env := azure.PublicCloud - if cfg.Cloud != "" { - env, err = azure.EnvironmentFromName(cfg.Cloud) - if err != nil { - klog.Errorf("failed to get environment from name %s: with error: %v", cfg.Cloud, err) - return nil, err - } - } - - return newAgentpoolClientWithConfig(cfg.SubscriptionID, cred, env.ResourceManagerEndpoint, env.TokenAudience, retryOptions) -} - type azClient struct { virtualMachineScaleSetsClient vmssclient.Interface virtualMachineScaleSetVMsClient vmssvmclient.Interface @@ -232,9 +257,11 @@ func newAzClient(cfg *Config, env *azure.Environment) (*azClient, error) { agentPoolClient, err := newAgentpoolClient(cfg) if err != nil 
{ - // we don't want to fail the whole process so we don't break any existing functionality - // since this may not be fatal - it is only used by vms pool which is still under development. - klog.Warningf("newAgentpoolClient failed with error: %s", err) + klog.Errorf("newAgentpoolClient failed with error: %s", err) + if cfg.EnableVMsAgentPool { + // only return error if VMs agent pool is supported which is controlled by toggle + return nil, err + } } return &azClient{ diff --git a/cluster-autoscaler/cloudprovider/azure/azure_cloud_provider_test.go b/cluster-autoscaler/cloudprovider/azure/azure_cloud_provider_test.go index cd88602da4791db125e43bb91cea3242723d1df8..342f4988e8cfe87fd1be1902f0eb0a8d2099fb91 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_cloud_provider_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_cloud_provider_test.go @@ -20,6 +20,7 @@ import ( "fmt" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/azure-sdk-for-go/services/resources/mgmt/2017-05-10/resources" "github.com/Azure/go-autorest/autorest/to" @@ -132,7 +133,7 @@ func TestNodeGroups(t *testing.T) { ) assert.True(t, registered) registered = provider.azureManager.RegisterNodeGroup( - newTestVMsPool(provider.azureManager, "test-vms-pool"), + newTestVMsPool(provider.azureManager), ) assert.True(t, registered) assert.Equal(t, len(provider.NodeGroups()), 2) @@ -146,9 +147,14 @@ func TestHasInstance(t *testing.T) { mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMClient := mockvmclient.NewMockInterface(ctrl) mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) provider.azureManager.azClient.virtualMachinesClient = mockVMClient provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient + provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient + provider.azureManager.azureCache.clusterName = "test-cluster" + provider.azureManager.azureCache.clusterResourceGroup = "test-rg" + provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types // Simulate node groups and instances expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform) @@ -158,6 +164,20 @@ func TestHasInstance(t *testing.T) { mockVMSSClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedScaleSets, nil).AnyTimes() mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes() mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets + vmssPool := armcontainerservice.AgentPool{ + Name: to.StringPtr("test-asg"), + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + Type: &vmssType, + }, + } + + vmsPool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool) + mockAgentpoolclient.EXPECT().NewListPager( + provider.azureManager.azureCache.clusterResourceGroup, + provider.azureManager.azureCache.clusterName, nil). 
+ Return(fakeAPListPager).AnyTimes() // Register node groups assert.Equal(t, len(provider.NodeGroups()), 0) @@ -168,9 +188,9 @@ func TestHasInstance(t *testing.T) { assert.True(t, registered) registered = provider.azureManager.RegisterNodeGroup( - newTestVMsPool(provider.azureManager, "test-vms-pool"), + newTestVMsPool(provider.azureManager), ) - provider.azureManager.explicitlyConfigured["test-vms-pool"] = true + provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true assert.True(t, registered) assert.Equal(t, len(provider.NodeGroups()), 2) @@ -264,9 +284,14 @@ func TestMixedNodeGroups(t *testing.T) { mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMClient := mockvmclient.NewMockInterface(ctrl) mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) provider.azureManager.azClient.virtualMachinesClient = mockVMClient provider.azureManager.azClient.virtualMachineScaleSetsClient = mockVMSSClient provider.azureManager.azClient.virtualMachineScaleSetVMsClient = mockVMSSVMClient + provider.azureManager.azureCache.clusterName = "test-cluster" + provider.azureManager.azureCache.clusterResourceGroup = "test-rg" + provider.azureManager.azureCache.enableVMsAgentPool = true // enable VMs agent pool to support mixed node group types + provider.azureManager.azClient.agentPoolClient = mockAgentpoolclient expectedScaleSets := newTestVMSSList(3, "test-asg", "eastus", compute.Uniform) expectedVMsPoolVMs := newTestVMsPoolVMList(3) @@ -276,6 +301,19 @@ func TestMixedNodeGroups(t *testing.T) { mockVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup).Return(expectedVMsPoolVMs, nil).AnyTimes() mockVMSSVMClient.EXPECT().List(gomock.Any(), provider.azureManager.config.ResourceGroup, "test-asg", gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + vmssType := armcontainerservice.AgentPoolTypeVirtualMachineScaleSets + vmssPool := armcontainerservice.AgentPool{ + Name: to.StringPtr("test-asg"), + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + Type: &vmssType, + }, + } + + vmsPool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&vmssPool, &vmsPool) + mockAgentpoolclient.EXPECT().NewListPager(provider.azureManager.azureCache.clusterResourceGroup, provider.azureManager.azureCache.clusterName, nil). 
+ Return(fakeAPListPager).AnyTimes() + assert.Equal(t, len(provider.NodeGroups()), 0) registered := provider.azureManager.RegisterNodeGroup( newTestScaleSet(provider.azureManager, "test-asg"), @@ -284,9 +322,9 @@ func TestMixedNodeGroups(t *testing.T) { assert.True(t, registered) registered = provider.azureManager.RegisterNodeGroup( - newTestVMsPool(provider.azureManager, "test-vms-pool"), + newTestVMsPool(provider.azureManager), ) - provider.azureManager.explicitlyConfigured["test-vms-pool"] = true + provider.azureManager.explicitlyConfigured[vmsNodeGroupName] = true assert.True(t, registered) assert.Equal(t, len(provider.NodeGroups()), 2) @@ -307,7 +345,7 @@ func TestMixedNodeGroups(t *testing.T) { group, err = provider.NodeGroupForNode(vmsPoolNode) assert.NoError(t, err) assert.NotNil(t, group, "Group should not be nil") - assert.Equal(t, group.Id(), "test-vms-pool") + assert.Equal(t, group.Id(), vmsNodeGroupName) assert.Equal(t, group.MinSize(), 3) assert.Equal(t, group.MaxSize(), 10) } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_config.go b/cluster-autoscaler/cloudprovider/azure/azure_config.go index b19da0fd2d9d9e4a7aaa15b7424195c06ee3ddd5..23c798e245440432a3a70ba229b3bf62d6745022 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_config.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_config.go @@ -86,6 +86,9 @@ type Config struct { // EnableForceDelete defines whether to enable force deletion on the APIs EnableForceDelete bool `json:"enableForceDelete,omitempty" yaml:"enableForceDelete,omitempty"` + // EnableVMsAgentPool defines whether to support VMs agentpool type in addition to VMSS type + EnableVMsAgentPool bool `json:"enableVMsAgentPool,omitempty" yaml:"enableVMsAgentPool,omitempty"` + // (DEPRECATED, DO NOT USE) EnableDynamicInstanceList defines whether to enable dynamic instance workflow for instance information check EnableDynamicInstanceList bool `json:"enableDynamicInstanceList,omitempty" yaml:"enableDynamicInstanceList,omitempty"` @@ -122,6 +125,7 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) { // Static defaults cfg.EnableDynamicInstanceList = false cfg.EnableVmssFlexNodes = false + cfg.EnableVMsAgentPool = false cfg.CloudProviderBackoffRetries = providerazureconsts.BackoffRetriesDefault cfg.CloudProviderBackoffExponent = providerazureconsts.BackoffExponentDefault cfg.CloudProviderBackoffDuration = providerazureconsts.BackoffDurationDefault @@ -257,6 +261,9 @@ func BuildAzureConfig(configReader io.Reader) (*Config, error) { if _, err = assignBoolFromEnvIfExists(&cfg.StrictCacheUpdates, "AZURE_STRICT_CACHE_UPDATES"); err != nil { return nil, err } + if _, err = assignBoolFromEnvIfExists(&cfg.EnableVMsAgentPool, "AZURE_ENABLE_VMS_AGENT_POOLS"); err != nil { + return nil, err + } if _, err = assignBoolFromEnvIfExists(&cfg.EnableDynamicInstanceList, "AZURE_ENABLE_DYNAMIC_INSTANCE_LIST"); err != nil { return nil, err } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_instance.go b/cluster-autoscaler/cloudprovider/azure/azure_instance.go index 9a04c441b51fc56b25200bc48e834e61b9db4d0c..f6e7b5fb0862579f775f011f2bea74de892f6dec 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_instance.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_instance.go @@ -22,80 +22,79 @@ import ( "regexp" "strings" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "k8s.io/klog/v2" ) -// GetVMSSTypeStatically uses static list of vmss generated at azure_instance_types.go to fetch vmss instance information. 
+// GetInstanceTypeStatically uses the static list of VMSS SKUs generated in azure_instance_types.go to fetch VMSS instance information.
 // It is declared as a variable for testing purpose.
-var GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) {
-	var vmssType *InstanceType
+var GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) {
+	var instanceType *InstanceType
 	for k := range InstanceTypes {
-		if strings.EqualFold(k, *template.Sku.Name) {
-			vmssType = InstanceTypes[k]
+		if strings.EqualFold(k, template.SkuName) {
+			instanceType = InstanceTypes[k]
 			break
 		}
 	}
 	promoRe := regexp.MustCompile(`(?i)_promo`)
-	if promoRe.MatchString(*template.Sku.Name) {
-		if vmssType == nil {
+	if promoRe.MatchString(template.SkuName) {
+		if instanceType == nil {
 			// We didn't find an exact match but this is a promo type, check for matching standard
-			klog.V(4).Infof("No exact match found for %s, checking standard types", *template.Sku.Name)
-			skuName := promoRe.ReplaceAllString(*template.Sku.Name, "")
+			klog.V(4).Infof("No exact match found for %s, checking standard types", template.SkuName)
+			skuName := promoRe.ReplaceAllString(template.SkuName, "")
 			for k := range InstanceTypes {
 				if strings.EqualFold(k, skuName) {
-					vmssType = InstanceTypes[k]
+					instanceType = InstanceTypes[k]
 					break
 				}
 			}
 		}
 	}
-	if vmssType == nil {
-		return vmssType, fmt.Errorf("instance type %q not supported", *template.Sku.Name)
+	if instanceType == nil {
+		return instanceType, fmt.Errorf("instance type %q not supported", template.SkuName)
 	}
-	return vmssType, nil
+	return instanceType, nil
 }
-// GetVMSSTypeDynamically fetched vmss instance information using sku api calls.
+// GetInstanceTypeDynamically fetches VMSS instance information using SKU API calls.
 // It is declared as a variable for testing purpose.
-var GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) {
+var GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) {
 	ctx := context.Background()
-	var vmssType InstanceType
+	var instanceType InstanceType
-	sku, err := azCache.GetSKU(ctx, *template.Sku.Name, *template.Location)
+	sku, err := azCache.GetSKU(ctx, template.SkuName, template.Location)
 	if err != nil {
 		// We didn't find an exact match but this is a promo type, check for matching standard
 		promoRe := regexp.MustCompile(`(?i)_promo`)
-		skuName := promoRe.ReplaceAllString(*template.Sku.Name, "")
-		if skuName != *template.Sku.Name {
-			klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", *template.Sku.Name, skuName, err)
-			sku, err = azCache.GetSKU(ctx, skuName, *template.Location)
+		skuName := promoRe.ReplaceAllString(template.SkuName, "")
+		if skuName != template.SkuName {
+			klog.V(1).Infof("No exact match found for %q, checking standard type %q. Error %v", template.SkuName, skuName, err)
+			sku, err = azCache.GetSKU(ctx, skuName, template.Location)
 		}
 		if err != nil {
-			return vmssType, fmt.Errorf("instance type %q not supported. Error %v", *template.Sku.Name, err)
+			return instanceType, fmt.Errorf("instance type %q not supported.
Error %v", template.SkuName, err) } } - vmssType.VCPU, err = sku.VCPU() + instanceType.VCPU, err = sku.VCPU() if err != nil { - klog.V(1).Infof("Failed to parse vcpu from sku %q %v", *template.Sku.Name, err) - return vmssType, err + klog.V(1).Infof("Failed to parse vcpu from sku %q %v", template.SkuName, err) + return instanceType, err } gpu, err := getGpuFromSku(sku) if err != nil { - klog.V(1).Infof("Failed to parse gpu from sku %q %v", *template.Sku.Name, err) - return vmssType, err + klog.V(1).Infof("Failed to parse gpu from sku %q %v", template.SkuName, err) + return instanceType, err } - vmssType.GPU = gpu + instanceType.GPU = gpu memoryGb, err := sku.Memory() if err != nil { - klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", *template.Sku.Name, err) - return vmssType, err + klog.V(1).Infof("Failed to parse memoryMb from sku %q %v", template.SkuName, err) + return instanceType, err } - vmssType.MemoryMb = int64(memoryGb) * 1024 + instanceType.MemoryMb = int64(memoryGb) * 1024 - return vmssType, nil + return instanceType, nil } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_manager.go b/cluster-autoscaler/cloudprovider/azure/azure_manager.go index ad9a5d836f02e66e4182dda624a5297a3e23a3ee..4ea25d1eb1b3ace4b26c846b973486ce9110cb14 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_manager.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_manager.go @@ -168,6 +168,23 @@ func (m *AzureManager) fetchExplicitNodeGroups(specs []string) error { return nil } +// parseSKUAndVMsAgentpoolNameFromSpecName parses the spec name for a mixed-SKU VMs pool. +// The spec name should be in the format <agentpoolname>/<sku>, e.g., "mypool1/Standard_D2s_v3", if the agent pool is a VMs pool. +// This method returns a boolean indicating if the agent pool is a VMs pool, along with the agent pool name and SKU. +func (m *AzureManager) parseSKUAndVMsAgentpoolNameFromSpecName(name string) (bool, string, string) { + parts := strings.Split(name, "/") + if len(parts) == 2 { + agentPoolName := parts[0] + sku := parts[1] + + vmsPoolMap := m.azureCache.getVMsPoolMap() + if _, ok := vmsPoolMap[agentPoolName]; ok { + return true, agentPoolName, sku + } + } + return false, "", "" +} + func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGroup, error) { scaleToZeroSupported := scaleToZeroSupportedStandard if strings.EqualFold(m.config.VMType, providerazureconsts.VMTypeVMSS) { @@ -177,9 +194,13 @@ func (m *AzureManager) buildNodeGroupFromSpec(spec string) (cloudprovider.NodeGr if err != nil { return nil, fmt.Errorf("failed to parse node group spec: %v", err) } - vmsPoolSet := m.azureCache.getVMsPoolSet() - if _, ok := vmsPoolSet[s.Name]; ok { - return NewVMsPool(s, m), nil + + // Starting from release 1.30, a cluster may have both VMSS and VMs pools. + // Therefore, we cannot solely rely on the VMType to determine the node group type. + // Instead, we need to check the cache to determine if the agent pool is a VMs pool. 
+ isVMsPool, agentPoolName, sku := m.parseSKUAndVMsAgentpoolNameFromSpecName(s.Name) + if isVMsPool { + return NewVMPool(s, m, agentPoolName, sku) } switch m.config.VMType { diff --git a/cluster-autoscaler/cloudprovider/azure/azure_manager_test.go b/cluster-autoscaler/cloudprovider/azure/azure_manager_test.go index c4a8662414b82fb1fa9eb4908f04d101436e2aa6..6efd60756a4cdafaafac84f2f9872ddb1e6700cf 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_manager_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_manager_test.go @@ -297,6 +297,7 @@ func TestCreateAzureManagerValidConfig(t *testing.T) { VmssVmsCacheJitter: 120, MaxDeploymentsCount: 8, EnableFastDeleteOnFailedProvisioning: true, + EnableVMsAgentPool: false, } assert.NoError(t, err) @@ -618,9 +619,14 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) { mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) mockVMSSClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachineScaleSet{}, nil).AnyTimes() mockVMClient.EXPECT().List(gomock.Any(), "resourceGroup").Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + vmspool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&vmspool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager).AnyTimes() mockAzClient := &azClient{ virtualMachinesClient: mockVMClient, virtualMachineScaleSetsClient: mockVMSSClient, + agentPoolClient: mockAgentpoolclient, } expectedConfig := &Config{ @@ -702,6 +708,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) { VmssVmsCacheJitter: 90, MaxDeploymentsCount: 8, EnableFastDeleteOnFailedProvisioning: true, + EnableVMsAgentPool: true, } t.Setenv("ARM_CLOUD", "AzurePublicCloud") @@ -735,6 +742,7 @@ func TestCreateAzureManagerWithNilConfig(t *testing.T) { t.Setenv("ARM_CLUSTER_RESOURCE_GROUP", "myrg") t.Setenv("ARM_BASE_URL_FOR_AP_CLIENT", "nodeprovisioner-svc.nodeprovisioner.svc.cluster.local") t.Setenv("AZURE_ENABLE_FAST_DELETE_ON_FAILED_PROVISIONING", "true") + t.Setenv("AZURE_ENABLE_VMS_AGENT_POOLS", "true") t.Run("environment variables correctly set", func(t *testing.T) { manager, err := createAzureManagerInternal(nil, cloudprovider.NodeGroupDiscoveryOptions{}, mockAzClient) diff --git a/cluster-autoscaler/cloudprovider/azure/azure_mock_agentpool_client.go b/cluster-autoscaler/cloudprovider/azure/azure_mock_agentpool_client.go index eaad11f01d4cbe82eea54595695564f6efa429b7..0e63d30b6465d44a93e146f729c36e0e8d7d440e 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_mock_agentpool_client.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_mock_agentpool_client.go @@ -21,7 +21,7 @@ import ( reflect "reflect" runtime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" - armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4" + armcontainerservice "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" gomock "go.uber.org/mock/gomock" ) @@ -49,46 +49,60 @@ func (m *MockAgentPoolsClient) EXPECT() *MockAgentPoolsClientMockRecorder { } // BeginCreateOrUpdate mocks base method. 
-func (m *MockAgentPoolsClient) BeginCreateOrUpdate(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPool, arg5 *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) { +func (m *MockAgentPoolsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, parameters armcontainerservice.AgentPool, options *armcontainerservice.AgentPoolsClientBeginCreateOrUpdateOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse], error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginCreateOrUpdate", arg0, arg1, arg2, arg3, arg4, arg5) + ret := m.ctrl.Call(m, "BeginCreateOrUpdate", ctx, resourceGroupName, resourceName, agentPoolName, parameters, options) ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]) ret1, _ := ret[1].(error) return ret0, ret1 } // BeginCreateOrUpdate indicates an expected call of BeginCreateOrUpdate. -func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { +func (mr *MockAgentPoolsClientMockRecorder) BeginCreateOrUpdate(ctx, resourceGroupName, resourceName, agentPoolName, parameters, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), arg0, arg1, arg2, arg3, arg4, arg5) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginCreateOrUpdate", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginCreateOrUpdate), ctx, resourceGroupName, resourceName, agentPoolName, parameters, options) } // BeginDeleteMachines mocks base method. -func (m *MockAgentPoolsClient) BeginDeleteMachines(arg0 context.Context, arg1, arg2, arg3 string, arg4 armcontainerservice.AgentPoolDeleteMachinesParameter, arg5 *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) { +func (m *MockAgentPoolsClient) BeginDeleteMachines(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, machines armcontainerservice.AgentPoolDeleteMachinesParameter, options *armcontainerservice.AgentPoolsClientBeginDeleteMachinesOptions) (*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse], error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginDeleteMachines", arg0, arg1, arg2, arg3, arg4, arg5) + ret := m.ctrl.Call(m, "BeginDeleteMachines", ctx, resourceGroupName, resourceName, agentPoolName, machines, options) ret0, _ := ret[0].(*runtime.Poller[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]) ret1, _ := ret[1].(error) return ret0, ret1 } // BeginDeleteMachines indicates an expected call of BeginDeleteMachines. 
-func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { +func (mr *MockAgentPoolsClientMockRecorder) BeginDeleteMachines(ctx, resourceGroupName, resourceName, agentPoolName, machines, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), arg0, arg1, arg2, arg3, arg4, arg5) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginDeleteMachines", reflect.TypeOf((*MockAgentPoolsClient)(nil).BeginDeleteMachines), ctx, resourceGroupName, resourceName, agentPoolName, machines, options) } // Get mocks base method. -func (m *MockAgentPoolsClient) Get(arg0 context.Context, arg1, arg2, arg3 string, arg4 *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) { +func (m *MockAgentPoolsClient) Get(ctx context.Context, resourceGroupName, resourceName, agentPoolName string, options *armcontainerservice.AgentPoolsClientGetOptions) (armcontainerservice.AgentPoolsClientGetResponse, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, resourceName, agentPoolName, options) ret0, _ := ret[0].(armcontainerservice.AgentPoolsClientGetResponse) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockAgentPoolsClientMockRecorder) Get(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +func (mr *MockAgentPoolsClientMockRecorder) Get(ctx, resourceGroupName, resourceName, agentPoolName, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAgentPoolsClient)(nil).Get), ctx, resourceGroupName, resourceName, agentPoolName, options) +} + +// NewListPager mocks base method. +func (m *MockAgentPoolsClient) NewListPager(resourceGroupName, resourceName string, options *armcontainerservice.AgentPoolsClientListOptions) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, resourceName, options) + ret0, _ := ret[0].(*runtime.Pager[armcontainerservice.AgentPoolsClientListResponse]) + return ret0 +} + +// NewListPager indicates an expected call of NewListPager. +func (mr *MockAgentPoolsClientMockRecorder) NewListPager(resourceGroupName, resourceName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockAgentPoolsClient)(nil).NewListPager), resourceGroupName, resourceName, options) } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_scale_set.go b/cluster-autoscaler/cloudprovider/azure/azure_scale_set.go index cb5d93fffed769a706ed094a6b5fbaf5805f72e2..64a04e77e118720c55587d6fb9940718a4077c26 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_scale_set.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_scale_set.go @@ -651,15 +651,18 @@ func (scaleSet *ScaleSet) Debug() string { // TemplateNodeInfo returns a node template for this scale set. 
func (scaleSet *ScaleSet) TemplateNodeInfo() (*framework.NodeInfo, error) { - template, err := scaleSet.getVMSSFromCache() + vmss, err := scaleSet.getVMSSFromCache() if err != nil { return nil, err } inputLabels := map[string]string{} inputTaints := "" - node, err := buildNodeFromTemplate(scaleSet.Name, inputLabels, inputTaints, template, scaleSet.manager, scaleSet.enableDynamicInstanceList) - + template, err := buildNodeTemplateFromVMSS(vmss, inputLabels, inputTaints) + if err != nil { + return nil, err + } + node, err := buildNodeFromTemplate(scaleSet.Name, template, scaleSet.manager, scaleSet.enableDynamicInstanceList) if err != nil { return nil, err } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_scale_set_test.go b/cluster-autoscaler/cloudprovider/azure/azure_scale_set_test.go index b78d7f71b2575905aac650404d4627f5b236e755..7d5ae4d263f085cb379409cc789293e3e83f28ef 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_scale_set_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_scale_set_test.go @@ -1232,12 +1232,12 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) { // Properly testing dynamic SKU list through skewer is not possible, // because there are no Resource API mocks included yet. // Instead, the rest of the (consumer side) tests here - // override GetVMSSTypeDynamically and GetVMSSTypeStatically functions. + // override GetInstanceTypeDynamically and GetInstanceTypeStatically functions. t.Run("Checking dynamic workflow", func(t *testing.T) { asg.enableDynamicInstanceList = true - GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { + GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) { vmssType := InstanceType{} vmssType.VCPU = 1 vmssType.GPU = 2 @@ -1255,10 +1255,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) { t.Run("Checking static workflow if dynamic fails", func(t *testing.T) { asg.enableDynamicInstanceList = true - GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { + GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) { return InstanceType{}, fmt.Errorf("dynamic error exists") } - GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { + GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) { vmssType := InstanceType{} vmssType.VCPU = 1 vmssType.GPU = 2 @@ -1276,10 +1276,10 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) { t.Run("Fails to find vmss instance information using static and dynamic workflow, instance not supported", func(t *testing.T) { asg.enableDynamicInstanceList = true - GetVMSSTypeDynamically = func(template compute.VirtualMachineScaleSet, azCache *azureCache) (InstanceType, error) { + GetInstanceTypeDynamically = func(template NodeTemplate, azCache *azureCache) (InstanceType, error) { return InstanceType{}, fmt.Errorf("dynamic error exists") } - GetVMSSTypeStatically = func(template compute.VirtualMachineScaleSet) (*InstanceType, error) { + GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) { return &InstanceType{}, fmt.Errorf("static error exists") } nodeInfo, err := asg.TemplateNodeInfo() @@ -1292,7 +1292,7 @@ func TestScaleSetTemplateNodeInfo(t *testing.T) { t.Run("Checking static-only workflow", func(t *testing.T) { asg.enableDynamicInstanceList = false - GetVMSSTypeStatically = func(template 
compute.VirtualMachineScaleSet) (*InstanceType, error) { + GetInstanceTypeStatically = func(template NodeTemplate) (*InstanceType, error) { vmssType := InstanceType{} vmssType.VCPU = 1 vmssType.GPU = 2 diff --git a/cluster-autoscaler/cloudprovider/azure/azure_template.go b/cluster-autoscaler/cloudprovider/azure/azure_template.go index 9411354be35b7f0d4b805064ea2f7a4d1b5f80da..03277640be36a381844dbdc2cbf36bc2f74391ef 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_template.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_template.go @@ -24,7 +24,9 @@ import ( "strings" "time" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/go-autorest/autorest/to" apiv1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -84,8 +86,132 @@ const ( clusterLabelKey = AKSLabelKeyPrefixValue + "cluster" ) -func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, inputTaints string, - template compute.VirtualMachineScaleSet, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) { +// VMPoolNodeTemplate holds properties for node from VMPool +type VMPoolNodeTemplate struct { + AgentPoolName string + Taints []apiv1.Taint + Labels map[string]*string + OSDiskType *armcontainerservice.OSDiskType +} + +// VMSSNodeTemplate holds properties for node from VMSS +type VMSSNodeTemplate struct { + InputLabels map[string]string + InputTaints string + Tags map[string]*string + OSDisk *compute.VirtualMachineScaleSetOSDisk +} + +// NodeTemplate represents a template for an Azure node +type NodeTemplate struct { + SkuName string + InstanceOS string + Location string + Zones []string + VMPoolNodeTemplate *VMPoolNodeTemplate + VMSSNodeTemplate *VMSSNodeTemplate +} + +func buildNodeTemplateFromVMSS(vmss compute.VirtualMachineScaleSet, inputLabels map[string]string, inputTaints string) (NodeTemplate, error) { + instanceOS := cloudprovider.DefaultOS + if vmss.VirtualMachineProfile != nil && + vmss.VirtualMachineProfile.OsProfile != nil && + vmss.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil { + instanceOS = "windows" + } + + var osDisk *compute.VirtualMachineScaleSetOSDisk + if vmss.VirtualMachineProfile != nil && + vmss.VirtualMachineProfile.StorageProfile != nil && + vmss.VirtualMachineProfile.StorageProfile.OsDisk != nil { + osDisk = vmss.VirtualMachineProfile.StorageProfile.OsDisk + } + + if vmss.Sku == nil || vmss.Sku.Name == nil { + return NodeTemplate{}, fmt.Errorf("VMSS %s has no SKU", to.String(vmss.Name)) + } + + if vmss.Location == nil { + return NodeTemplate{}, fmt.Errorf("VMSS %s has no location", to.String(vmss.Name)) + } + + zones := []string{} + if vmss.Zones != nil { + zones = *vmss.Zones + } + + return NodeTemplate{ + SkuName: *vmss.Sku.Name, + + Location: *vmss.Location, + Zones: zones, + InstanceOS: instanceOS, + VMSSNodeTemplate: &VMSSNodeTemplate{ + InputLabels: inputLabels, + InputTaints: inputTaints, + OSDisk: osDisk, + Tags: vmss.Tags, + }, + }, nil +} + +func buildNodeTemplateFromVMPool(vmsPool armcontainerservice.AgentPool, location string, skuName string, labelsFromSpec map[string]string, taintsFromSpec string) (NodeTemplate, error) { + if vmsPool.Properties == nil { + return NodeTemplate{}, fmt.Errorf("vmsPool %s has nil properties", to.String(vmsPool.Name)) + } + // labels from the agentpool + labels := vmsPool.Properties.NodeLabels + 
// labels from spec + for k, v := range labelsFromSpec { + if labels == nil { + labels = make(map[string]*string) + } + labels[k] = to.StringPtr(v) + } + + // taints from the agentpool + taintsList := []string{} + for _, taint := range vmsPool.Properties.NodeTaints { + if to.String(taint) != "" { + taintsList = append(taintsList, to.String(taint)) + } + } + // taints from spec + if taintsFromSpec != "" { + taintsList = append(taintsList, taintsFromSpec) + } + taintsStr := strings.Join(taintsList, ",") + taints := extractTaintsFromSpecString(taintsStr) + + var zones []string + if vmsPool.Properties.AvailabilityZones != nil { + for _, zone := range vmsPool.Properties.AvailabilityZones { + if zone != nil { + zones = append(zones, *zone) + } + } + } + + var instanceOS string + if vmsPool.Properties.OSType != nil { + instanceOS = strings.ToLower(string(*vmsPool.Properties.OSType)) + } + + return NodeTemplate{ + SkuName: skuName, + Zones: zones, + InstanceOS: instanceOS, + Location: location, + VMPoolNodeTemplate: &VMPoolNodeTemplate{ + AgentPoolName: to.String(vmsPool.Name), + OSDiskType: vmsPool.Properties.OSDiskType, + Taints: taints, + Labels: labels, + }, + }, nil +} + +func buildNodeFromTemplate(nodeGroupName string, template NodeTemplate, manager *AzureManager, enableDynamicInstanceList bool) (*apiv1.Node, error) { node := apiv1.Node{} nodeName := fmt.Sprintf("%s-asg-%d", nodeGroupName, rand.Int63()) @@ -104,28 +230,28 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, // Fetching SKU information from SKU API if enableDynamicInstanceList is true. var dynamicErr error if enableDynamicInstanceList { - var vmssTypeDynamic InstanceType - klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", *template.Sku.Name) - vmssTypeDynamic, dynamicErr = GetVMSSTypeDynamically(template, manager.azureCache) + var instanceTypeDynamic InstanceType + klog.V(1).Infof("Fetching instance information for SKU: %s from SKU API", template.SkuName) + instanceTypeDynamic, dynamicErr = GetInstanceTypeDynamically(template, manager.azureCache) if dynamicErr == nil { - vcpu = vmssTypeDynamic.VCPU - gpuCount = vmssTypeDynamic.GPU - memoryMb = vmssTypeDynamic.MemoryMb + vcpu = instanceTypeDynamic.VCPU + gpuCount = instanceTypeDynamic.GPU + memoryMb = instanceTypeDynamic.MemoryMb } else { klog.Errorf("Dynamically fetching of instance information from SKU api failed with error: %v", dynamicErr) } } if !enableDynamicInstanceList || dynamicErr != nil { - klog.V(1).Infof("Falling back to static SKU list for SKU: %s", *template.Sku.Name) + klog.V(1).Infof("Falling back to static SKU list for SKU: %s", template.SkuName) // fall-back on static list of vmss if dynamic workflow fails. - vmssTypeStatic, staticErr := GetVMSSTypeStatically(template) + instanceTypeStatic, staticErr := GetInstanceTypeStatically(template) if staticErr == nil { - vcpu = vmssTypeStatic.VCPU - gpuCount = vmssTypeStatic.GPU - memoryMb = vmssTypeStatic.MemoryMb + vcpu = instanceTypeStatic.VCPU + gpuCount = instanceTypeStatic.GPU + memoryMb = instanceTypeStatic.MemoryMb } else { // return error if neither of the workflows results with vmss data. 
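(Aside, illustrative only; not part of the patch.) For a VMs pool the two taint sources above are merged before parsing. Using the sample values from the buildNodeTaintsForVMPool comment further down, and assuming an agent pool carrying NodeTaints ["zone=dmz:NoSchedule"] plus a spec-level taint string "usage=monitoring:NoSchedule":

// joined string handed to extractTaintsFromSpecString:
//   "zone=dmz:NoSchedule,usage=monitoring:NoSchedule"
// resulting node taints:
//   {Key: "zone", Value: "dmz", Effect: NoSchedule}
//   {Key: "usage", Value: "monitoring", Effect: NoSchedule}
// a repeated "key=value:Effect" entry from either source is emitted only once, thanks to
// the dedupMap that extractTaintsFromSpecString now keeps.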
- klog.V(1).Infof("Instance type %q not supported, err: %v", *template.Sku.Name, staticErr) + klog.V(1).Infof("Instance type %q not supported, err: %v", template.SkuName, staticErr) return nil, staticErr } } @@ -134,7 +260,7 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, node.Status.Capacity[apiv1.ResourceCPU] = *resource.NewQuantity(vcpu, resource.DecimalSI) // isNPSeries returns if a SKU is an NP-series SKU // SKU API reports GPUs for NP-series but it's actually FPGAs - if isNPSeries(*template.Sku.Name) { + if isNPSeries(template.SkuName) { node.Status.Capacity[xilinxFpgaResourceName] = *resource.NewQuantity(gpuCount, resource.DecimalSI) } else { node.Status.Capacity[gpu.ResourceNvidiaGPU] = *resource.NewQuantity(gpuCount, resource.DecimalSI) @@ -145,9 +271,37 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, // TODO: set real allocatable. node.Status.Allocatable = node.Status.Capacity + if template.VMSSNodeTemplate != nil { + node = processVMSSTemplate(template, nodeName, node) + } else if template.VMPoolNodeTemplate != nil { + node = processVMPoolTemplate(template, nodeName, node) + } else { + return nil, fmt.Errorf("invalid node template: missing both VMSS and VMPool templates") + } + + klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels) + klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints) + node.Status.Conditions = cloudprovider.BuildReadyConditions() + return &node, nil +} + +func processVMPoolTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node { + labels := buildGenericLabels(template, nodeName) + labels[agentPoolNodeLabelKey] = template.VMPoolNodeTemplate.AgentPoolName + if template.VMPoolNodeTemplate.Labels != nil { + for k, v := range template.VMPoolNodeTemplate.Labels { + labels[k] = to.String(v) + } + } + node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels) + node.Spec.Taints = template.VMPoolNodeTemplate.Taints + return node +} + +func processVMSSTemplate(template NodeTemplate, nodeName string, node apiv1.Node) apiv1.Node { // NodeLabels - if template.Tags != nil { - for k, v := range template.Tags { + if template.VMSSNodeTemplate.Tags != nil { + for k, v := range template.VMSSNodeTemplate.Tags { if v != nil { node.Labels[k] = *v } else { @@ -164,10 +318,10 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, labels := make(map[string]string) // Prefer the explicit labels in spec coming from RP over the VMSS template - if len(inputLabels) > 0 { - labels = inputLabels + if len(template.VMSSNodeTemplate.InputLabels) > 0 { + labels = template.VMSSNodeTemplate.InputLabels } else { - labels = extractLabelsFromScaleSet(template.Tags) + labels = extractLabelsFromTags(template.VMSSNodeTemplate.Tags) } // Add the agentpool label, its value should come from the VMSS poolName tag @@ -182,87 +336,74 @@ func buildNodeFromTemplate(nodeGroupName string, inputLabels map[string]string, labels[agentPoolNodeLabelKey] = node.Labels[poolNameTag] } - // Add the storage profile and storage tier labels - if template.VirtualMachineProfile != nil && template.VirtualMachineProfile.StorageProfile != nil && template.VirtualMachineProfile.StorageProfile.OsDisk != nil { + // Add the storage profile and storage tier labels for vmss node + if template.VMSSNodeTemplate.OSDisk != nil { // ephemeral - if template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings != nil && 
template.VirtualMachineProfile.StorageProfile.OsDisk.DiffDiskSettings.Option == compute.Local { + if template.VMSSNodeTemplate.OSDisk.DiffDiskSettings != nil && template.VMSSNodeTemplate.OSDisk.DiffDiskSettings.Option == compute.Local { labels[legacyStorageProfileNodeLabelKey] = "ephemeral" labels[storageProfileNodeLabelKey] = "ephemeral" } else { labels[legacyStorageProfileNodeLabelKey] = "managed" labels[storageProfileNodeLabelKey] = "managed" } - if template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk != nil { - labels[legacyStorageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType) - labels[storageTierNodeLabelKey] = string(template.VirtualMachineProfile.StorageProfile.OsDisk.ManagedDisk.StorageAccountType) + if template.VMSSNodeTemplate.OSDisk.ManagedDisk != nil { + labels[legacyStorageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType) + labels[storageTierNodeLabelKey] = string(template.VMSSNodeTemplate.OSDisk.ManagedDisk.StorageAccountType) } // Add ephemeral-storage value - if template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB != nil { - node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI) - klog.V(4).Infof("OS Disk Size from template is: %d", *template.VirtualMachineProfile.StorageProfile.OsDisk.DiskSizeGB) + if template.VMSSNodeTemplate.OSDisk.DiskSizeGB != nil { + node.Status.Capacity[apiv1.ResourceEphemeralStorage] = *resource.NewQuantity(int64(int(*template.VMSSNodeTemplate.OSDisk.DiskSizeGB)*1024*1024*1024), resource.DecimalSI) + klog.V(4).Infof("OS Disk Size from template is: %d", *template.VMSSNodeTemplate.OSDisk.DiskSizeGB) klog.V(4).Infof("Setting ephemeral storage to: %v", node.Status.Capacity[apiv1.ResourceEphemeralStorage]) } } // If we are on GPU-enabled SKUs, append the accelerator // label so that CA makes better decision when scaling from zero for GPU pools - if isNvidiaEnabledSKU(*template.Sku.Name) { + if isNvidiaEnabledSKU(template.SkuName) { labels[GPULabel] = "nvidia" labels[legacyGPULabel] = "nvidia" } // Extract allocatables from tags - resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.Tags) + resourcesFromTags := extractAllocatableResourcesFromScaleSet(template.VMSSNodeTemplate.Tags) for resourceName, val := range resourcesFromTags { node.Status.Capacity[apiv1.ResourceName(resourceName)] = *val } node.Labels = cloudprovider.JoinStringMaps(node.Labels, labels) - klog.V(4).Infof("Setting node %s labels to: %s", nodeName, node.Labels) var taints []apiv1.Taint - // Prefer the explicit taints in spec over the VMSS template - if inputTaints != "" { - taints = extractTaintsFromSpecString(inputTaints) + // Prefer the explicit taints in spec over the tags from vmss or vm + if template.VMSSNodeTemplate.InputTaints != "" { + taints = extractTaintsFromSpecString(template.VMSSNodeTemplate.InputTaints) } else { - taints = extractTaintsFromScaleSet(template.Tags) + taints = extractTaintsFromTags(template.VMSSNodeTemplate.Tags) } // Taints from the Scale Set's Tags node.Spec.Taints = taints - klog.V(4).Infof("Setting node %s taints to: %s", nodeName, node.Spec.Taints) - - node.Status.Conditions = cloudprovider.BuildReadyConditions() - return &node, nil + return node } -func buildInstanceOS(template compute.VirtualMachineScaleSet) string { - instanceOS := cloudprovider.DefaultOS - if 
template.VirtualMachineProfile != nil && template.VirtualMachineProfile.OsProfile != nil && template.VirtualMachineProfile.OsProfile.WindowsConfiguration != nil { - instanceOS = "windows" - } - - return instanceOS -} - -func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string) map[string]string { +func buildGenericLabels(template NodeTemplate, nodeName string) map[string]string { result := make(map[string]string) result[kubeletapis.LabelArch] = cloudprovider.DefaultArch result[apiv1.LabelArchStable] = cloudprovider.DefaultArch - result[kubeletapis.LabelOS] = buildInstanceOS(template) - result[apiv1.LabelOSStable] = buildInstanceOS(template) + result[kubeletapis.LabelOS] = template.InstanceOS + result[apiv1.LabelOSStable] = template.InstanceOS - result[apiv1.LabelInstanceType] = *template.Sku.Name - result[apiv1.LabelInstanceTypeStable] = *template.Sku.Name - result[apiv1.LabelZoneRegion] = strings.ToLower(*template.Location) - result[apiv1.LabelTopologyRegion] = strings.ToLower(*template.Location) + result[apiv1.LabelInstanceType] = template.SkuName + result[apiv1.LabelInstanceTypeStable] = template.SkuName + result[apiv1.LabelZoneRegion] = strings.ToLower(template.Location) + result[apiv1.LabelTopologyRegion] = strings.ToLower(template.Location) - if template.Zones != nil && len(*template.Zones) > 0 { - failureDomains := make([]string, len(*template.Zones)) - for k, v := range *template.Zones { - failureDomains[k] = strings.ToLower(*template.Location) + "-" + v + if len(template.Zones) > 0 { + failureDomains := make([]string, len(template.Zones)) + for k, v := range template.Zones { + failureDomains[k] = strings.ToLower(template.Location) + "-" + v } //Picks random zones for Multi-zone nodepool when scaling from zero. //This random zone will not be the same as the zone of the VMSS that is being created, the purpose of creating @@ -283,7 +424,7 @@ func buildGenericLabels(template compute.VirtualMachineScaleSet, nodeName string return result } -func extractLabelsFromScaleSet(tags map[string]*string) map[string]string { +func extractLabelsFromTags(tags map[string]*string) map[string]string { result := make(map[string]string) for tagName, tagValue := range tags { @@ -300,7 +441,7 @@ func extractLabelsFromScaleSet(tags map[string]*string) map[string]string { return result } -func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint { +func extractTaintsFromTags(tags map[string]*string) []apiv1.Taint { taints := make([]apiv1.Taint, 0) for tagName, tagValue := range tags { @@ -327,35 +468,61 @@ func extractTaintsFromScaleSet(tags map[string]*string) []apiv1.Taint { return taints } +// extractTaintsFromSpecString is for nodepool taints // Example of a valid taints string, is the same argument to kubelet's `--register-with-taints` // "dedicated=foo:NoSchedule,group=bar:NoExecute,app=fizz:PreferNoSchedule" func extractTaintsFromSpecString(taintsString string) []apiv1.Taint { taints := make([]apiv1.Taint, 0) + dedupMap := make(map[string]interface{}) // First split the taints at the separator splits := strings.Split(taintsString, ",") for _, split := range splits { - taintSplit := strings.Split(split, "=") - if len(taintSplit) != 2 { + if dedupMap[split] != nil { continue } + dedupMap[split] = struct{}{} + valid, taint := constructTaintFromString(split) + if valid { + taints = append(taints, taint) + } + } + return taints +} - taintKey := taintSplit[0] - taintValue := taintSplit[1] - - r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)") - if 
!r.MatchString(taintValue) { - continue +// buildNodeTaintsForVMPool is for VMPool taints, it looks for the taints in the format +// []string{zone=dmz:NoSchedule, usage=monitoring:NoSchedule} +func buildNodeTaintsForVMPool(taintStrs []string) []apiv1.Taint { + taints := make([]apiv1.Taint, 0) + for _, taintStr := range taintStrs { + valid, taint := constructTaintFromString(taintStr) + if valid { + taints = append(taints, taint) } + } + return taints +} - values := strings.SplitN(taintValue, ":", 2) - taints = append(taints, apiv1.Taint{ - Key: taintKey, - Value: values[0], - Effect: apiv1.TaintEffect(values[1]), - }) +// constructTaintFromString constructs a taint from a string in the format <key>=<value>:<effect> +// if the input string is not in the correct format, it returns false and an empty taint +func constructTaintFromString(taintString string) (bool, apiv1.Taint) { + taintSplit := strings.Split(taintString, "=") + if len(taintSplit) != 2 { + return false, apiv1.Taint{} } + taintKey := taintSplit[0] + taintValue := taintSplit[1] - return taints + r, _ := regexp.Compile("(.*):(?:NoSchedule|NoExecute|PreferNoSchedule)") + if !r.MatchString(taintValue) { + return false, apiv1.Taint{} + } + + values := strings.SplitN(taintValue, ":", 2) + return true, apiv1.Taint{ + Key: taintKey, + Value: values[0], + Effect: apiv1.TaintEffect(values[1]), + } } func extractAutoscalingOptionsFromScaleSetTags(tags map[string]*string) map[string]string { diff --git a/cluster-autoscaler/cloudprovider/azure/azure_template_test.go b/cluster-autoscaler/cloudprovider/azure/azure_template_test.go index 3711a11ce37231dd989e23bdb3c0e5890a7ba7df..2cb327f4437acc5d730f4ee68fd2627d5c4b7bcb 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_template_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_template_test.go @@ -21,6 +21,7 @@ import ( "strings" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/to" @@ -30,7 +31,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" ) -func TestExtractLabelsFromScaleSet(t *testing.T) { +func TestExtractLabelsFromTags(t *testing.T) { expectedNodeLabelKey := "zip" expectedNodeLabelValue := "zap" extraNodeLabelValue := "buzz" @@ -52,14 +53,14 @@ func TestExtractLabelsFromScaleSet(t *testing.T) { fmt.Sprintf("%s%s", nodeLabelTagName, escapedUnderscoreNodeLabelKey): &escapedUnderscoreNodeLabelValue, } - labels := extractLabelsFromScaleSet(tags) + labels := extractLabelsFromTags(tags) assert.Len(t, labels, 3) assert.Equal(t, expectedNodeLabelValue, labels[expectedNodeLabelKey]) assert.Equal(t, escapedSlashNodeLabelValue, labels[expectedSlashEscapedNodeLabelKey]) assert.Equal(t, escapedUnderscoreNodeLabelValue, labels[expectedUnderscoreEscapedNodeLabelKey]) } -func TestExtractTaintsFromScaleSet(t *testing.T) { +func TestExtractTaintsFromTags(t *testing.T) { noScheduleTaintValue := "foo:NoSchedule" noExecuteTaintValue := "bar:NoExecute" preferNoScheduleTaintValue := "fizz:PreferNoSchedule" @@ -100,7 +101,7 @@ func TestExtractTaintsFromScaleSet(t *testing.T) { }, } - taints := extractTaintsFromScaleSet(tags) + taints := extractTaintsFromTags(tags) assert.Len(t, taints, 4) assert.Equal(t, makeTaintSet(expectedTaints), makeTaintSet(taints)) } @@ -137,6 +138,11 @@ func TestExtractTaintsFromSpecString(t *testing.T) { Value: "fizz", Effect: 
apiv1.TaintEffectPreferNoSchedule, }, + { + Key: "dedicated", // duplicate key, should be ignored + Value: "foo", + Effect: apiv1.TaintEffectNoSchedule, + }, } taints := extractTaintsFromSpecString(strings.Join(taintsString, ",")) @@ -176,8 +182,9 @@ func TestTopologyFromScaleSet(t *testing.T) { Location: to.StringPtr("westus"), } expectedZoneValues := []string{"westus-1", "westus-2", "westus-3"} - - labels := buildGenericLabels(testVmss, testNodeName) + template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "") + assert.NoError(t, err) + labels := buildGenericLabels(template, testNodeName) failureDomain, ok := labels[apiv1.LabelZoneFailureDomain] assert.True(t, ok) topologyZone, ok := labels[apiv1.LabelTopologyZone] @@ -205,7 +212,9 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) { expectedFailureDomain := "0" expectedTopologyZone := "0" expectedAzureDiskTopology := "" - labels := buildGenericLabels(testVmss, testNodeName) + template, err := buildNodeTemplateFromVMSS(testVmss, map[string]string{}, "") + assert.NoError(t, err) + labels := buildGenericLabels(template, testNodeName) failureDomain, ok := labels[apiv1.LabelZoneFailureDomain] assert.True(t, ok) @@ -219,6 +228,61 @@ func TestEmptyTopologyFromScaleSet(t *testing.T) { assert.True(t, ok) assert.Equal(t, expectedAzureDiskTopology, azureDiskTopology) } +func TestBuildNodeTemplateFromVMPool(t *testing.T) { + agentPoolName := "testpool" + location := "eastus" + skuName := "Standard_DS2_v2" + labelKey := "foo" + labelVal := "bar" + taintStr := "dedicated=foo:NoSchedule,boo=fizz:PreferNoSchedule,group=bar:NoExecute" + + osType := armcontainerservice.OSTypeLinux + osDiskType := armcontainerservice.OSDiskTypeEphemeral + zone1 := "1" + zone2 := "2" + + vmpool := armcontainerservice.AgentPool{ + Name: to.StringPtr(agentPoolName), + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + NodeLabels: map[string]*string{ + "existing": to.StringPtr("label"), + "department": to.StringPtr("engineering"), + }, + NodeTaints: []*string{to.StringPtr("group=bar:NoExecute")}, + OSType: &osType, + OSDiskType: &osDiskType, + AvailabilityZones: []*string{&zone1, &zone2}, + }, + } + + labelsFromSpec := map[string]string{labelKey: labelVal} + taintsFromSpec := taintStr + + template, err := buildNodeTemplateFromVMPool(vmpool, location, skuName, labelsFromSpec, taintsFromSpec) + assert.NoError(t, err) + assert.Equal(t, skuName, template.SkuName) + assert.Equal(t, location, template.Location) + assert.ElementsMatch(t, []string{zone1, zone2}, template.Zones) + assert.Equal(t, "linux", template.InstanceOS) + assert.NotNil(t, template.VMPoolNodeTemplate) + assert.Equal(t, agentPoolName, template.VMPoolNodeTemplate.AgentPoolName) + assert.Equal(t, &osDiskType, template.VMPoolNodeTemplate.OSDiskType) + // Labels: should include both from NodeLabels and labelsFromSpec + assert.Contains(t, template.VMPoolNodeTemplate.Labels, "existing") + assert.Equal(t, "label", *template.VMPoolNodeTemplate.Labels["existing"]) + assert.Contains(t, template.VMPoolNodeTemplate.Labels, "department") + assert.Equal(t, "engineering", *template.VMPoolNodeTemplate.Labels["department"]) + assert.Contains(t, template.VMPoolNodeTemplate.Labels, labelKey) + assert.Equal(t, labelVal, *template.VMPoolNodeTemplate.Labels[labelKey]) + // Taints: should include both from NodeTaints and taintsFromSpec + taintSet := makeTaintSet(template.VMPoolNodeTemplate.Taints) + expectedTaints := []apiv1.Taint{ + {Key: "group", Value: "bar", Effect: 
apiv1.TaintEffectNoExecute}, + {Key: "dedicated", Value: "foo", Effect: apiv1.TaintEffectNoSchedule}, + {Key: "boo", Value: "fizz", Effect: apiv1.TaintEffectPreferNoSchedule}, + } + assert.Equal(t, makeTaintSet(expectedTaints), taintSet) +} func makeTaintSet(taints []apiv1.Taint) map[apiv1.Taint]bool { set := make(map[apiv1.Taint]bool) diff --git a/cluster-autoscaler/cloudprovider/azure/azure_vms_pool.go b/cluster-autoscaler/cloudprovider/azure/azure_vms_pool.go index 34838c4e3aba258faaddb40949270e9f63af6a93..21aff31d3eeae3bf93b6a922186a552af56effeb 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_vms_pool.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_vms_pool.go @@ -18,142 +18,426 @@ package azure import ( "fmt" + "net/http" + "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/go-autorest/autorest/to" + apiv1 "k8s.io/api/core/v1" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider" "k8s.io/autoscaler/cluster-autoscaler/config" "k8s.io/autoscaler/cluster-autoscaler/config/dynamic" "k8s.io/autoscaler/cluster-autoscaler/simulator/framework" + klog "k8s.io/klog/v2" ) -// VMsPool is single instance VM pool -// this is a placeholder for now, no real implementation -type VMsPool struct { +// VMPool represents a group of standalone virtual machines (VMs) with a single SKU. +// It is part of a mixed-SKU agent pool (an agent pool with type `VirtualMachines`). +// Terminology: +// - Agent pool: A node pool in an AKS cluster. +// - VMs pool: An agent pool of type `VirtualMachines`, which can contain mixed SKUs. +// - VMPool: A subset of VMs within a VMs pool that share the same SKU. +type VMPool struct { azureRef manager *AzureManager - resourceGroup string + agentPoolName string // the virtual machines agentpool that this VMPool belongs to + sku string // sku of the VM in the pool minSize int maxSize int - - curSize int64 - // sizeMutex sync.Mutex - // lastSizeRefresh time.Time } -// NewVMsPool creates a new VMsPool -func NewVMsPool(spec *dynamic.NodeGroupSpec, am *AzureManager) *VMsPool { - nodepool := &VMsPool{ +// NewVMPool creates a new VMPool - a pool of standalone VMs of a single size. +func NewVMPool(spec *dynamic.NodeGroupSpec, am *AzureManager, agentPoolName string, sku string) (*VMPool, error) { + if am.azClient.agentPoolClient == nil { + return nil, fmt.Errorf("agentPoolClient is nil") + } + + nodepool := &VMPool{ azureRef: azureRef{ - Name: spec.Name, + Name: spec.Name, // in format "<agentPoolName>/<sku>" }, - manager: am, - resourceGroup: am.config.ResourceGroup, - - curSize: -1, - minSize: spec.MinSize, - maxSize: spec.MaxSize, + sku: sku, + agentPoolName: agentPoolName, + minSize: spec.MinSize, + maxSize: spec.MaxSize, } - - return nodepool + return nodepool, nil } -// MinSize returns the minimum size the cluster is allowed to scaled down +// MinSize returns the minimum size the vmPool is allowed to scaled down // to as provided by the node spec in --node parameter. 
-func (agentPool *VMsPool) MinSize() int { - return agentPool.minSize +func (vmPool *VMPool) MinSize() int { + return vmPool.minSize } -// Exist is always true since we are initialized with an existing agentpool -func (agentPool *VMsPool) Exist() bool { +// Exist is always true since we are initialized with an existing vmPool +func (vmPool *VMPool) Exist() bool { return true } // Create creates the node group on the cloud provider side. -func (agentPool *VMsPool) Create() (cloudprovider.NodeGroup, error) { +func (vmPool *VMPool) Create() (cloudprovider.NodeGroup, error) { return nil, cloudprovider.ErrAlreadyExist } // Delete deletes the node group on the cloud provider side. -func (agentPool *VMsPool) Delete() error { +func (vmPool *VMPool) Delete() error { + return cloudprovider.ErrNotImplemented +} + +// ForceDeleteNodes deletes nodes from the group regardless of constraints. +func (vmPool *VMPool) ForceDeleteNodes(nodes []*apiv1.Node) error { return cloudprovider.ErrNotImplemented } // Autoprovisioned is always false since we are initialized with an existing agentpool -func (agentPool *VMsPool) Autoprovisioned() bool { +func (vmPool *VMPool) Autoprovisioned() bool { return false } // GetOptions returns NodeGroupAutoscalingOptions that should be used for this particular // NodeGroup. Returning a nil will result in using default options. -func (agentPool *VMsPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) { - // TODO(wenxuan): Implement this method - return nil, cloudprovider.ErrNotImplemented +func (vmPool *VMPool) GetOptions(defaults config.NodeGroupAutoscalingOptions) (*config.NodeGroupAutoscalingOptions, error) { + // TODO(wenxuan): implement this method when vmPool can fully support GPU nodepool + return nil, nil } // MaxSize returns the maximum size scale limit provided by --node // parameter to the autoscaler main -func (agentPool *VMsPool) MaxSize() int { - return agentPool.maxSize +func (vmPool *VMPool) MaxSize() int { + return vmPool.maxSize } -// TargetSize returns the current TARGET size of the node group. It is possible that the -// number is different from the number of nodes registered in Kubernetes. -func (agentPool *VMsPool) TargetSize() (int, error) { - // TODO(wenxuan): Implement this method - return -1, cloudprovider.ErrNotImplemented +// TargetSize returns the current target size of the node group. This value represents +// the desired number of nodes in the VMPool, which may differ from the actual number +// of nodes currently present. +func (vmPool *VMPool) TargetSize() (int, error) { + // VMs in the "Deleting" state are not counted towards the target size. + size, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false}) + return int(size), err } -// IncreaseSize increase the size through a PUT AP call. It calculates the expected size -// based on a delta provided as parameter -func (agentPool *VMsPool) IncreaseSize(delta int) error { - // TODO(wenxuan): Implement this method - return cloudprovider.ErrNotImplemented +// IncreaseSize increases the size of the VMPool by sending a PUT request to update the agent pool. +// This method waits until the asynchronous PUT operation completes or the client-side timeout is reached. +func (vmPool *VMPool) IncreaseSize(delta int) error { + if delta <= 0 { + return fmt.Errorf("size increase must be positive, current delta: %d", delta) + } + + // Skip VMs in the failed state so that a PUT AP will be triggered to fix the failed VMs. 
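+ // Note: for non-spot pools, currentSize below counts only VMs of this SKU that are neither deleting nor failed; for spot pools getCurSize reads the size from the agent pool status instead.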
+ currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: true}) + if err != nil { + return err + } + + if int(currentSize)+delta > vmPool.MaxSize() { + return fmt.Errorf("size-increasing request of %d is bigger than max size %d", int(currentSize)+delta, vmPool.MaxSize()) + } + + updateCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout) + defer cancel() + + versionedAP, err := vmPool.getAgentpoolFromCache() + if err != nil { + klog.Errorf("Failed to get vmPool %s, error: %s", vmPool.agentPoolName, err) + return err + } + + count := currentSize + int32(delta) + requestBody := armcontainerservice.AgentPool{} + // self-hosted CAS will be using Manual scale profile + if len(versionedAP.Properties.VirtualMachinesProfile.Scale.Manual) > 0 { + requestBody = buildRequestBodyForScaleUp(versionedAP, count, vmPool.sku) + + } else { // AKS-managed CAS will use custom header for setting the target count + header := make(http.Header) + header.Set("Target-Count", fmt.Sprintf("%d", count)) + updateCtx = policy.WithHTTPHeader(updateCtx, header) + } + + defer vmPool.manager.invalidateCache() + poller, err := vmPool.manager.azClient.agentPoolClient.BeginCreateOrUpdate( + updateCtx, + vmPool.manager.config.ClusterResourceGroup, + vmPool.manager.config.ClusterName, + vmPool.agentPoolName, + requestBody, nil) + + if err != nil { + klog.Errorf("Failed to scale up agentpool %s in cluster %s for vmPool %s with error: %v", + vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, err) + return err + } + + if _, err := poller.PollUntilDone(updateCtx, nil /*default polling interval is 30s*/); err != nil { + klog.Errorf("agentPoolClient.BeginCreateOrUpdate for aks cluster %s agentpool %s for scaling up vmPool %s failed with error %s", + vmPool.manager.config.ClusterName, vmPool.agentPoolName, vmPool.Name, err) + return err + } + + klog.Infof("Successfully scaled up agentpool %s in cluster %s for vmPool %s to size %d", + vmPool.agentPoolName, vmPool.manager.config.ClusterName, vmPool.Name, count) + return nil } -// DeleteNodes extracts the providerIDs from the node spec and -// delete or deallocate the nodes from the agent pool based on the scale down policy. 
-func (agentPool *VMsPool) DeleteNodes(nodes []*apiv1.Node) error { - // TODO(wenxuan): Implement this method - return cloudprovider.ErrNotImplemented +// buildRequestBodyForScaleUp builds the request body for scale up for self-hosted CAS +func buildRequestBodyForScaleUp(agentpool armcontainerservice.AgentPool, count int32, vmSku string) armcontainerservice.AgentPool { + requestBody := armcontainerservice.AgentPool{ + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + Type: agentpool.Properties.Type, + }, + } + + // the request body must have the same mode as the original agentpool + // otherwise the PUT request will fail + if agentpool.Properties.Mode != nil && + *agentpool.Properties.Mode == armcontainerservice.AgentPoolModeSystem { + systemMode := armcontainerservice.AgentPoolModeSystem + requestBody.Properties.Mode = &systemMode + } + + // set the count of the matching manual scale profile to the new target value + for _, manualProfile := range agentpool.Properties.VirtualMachinesProfile.Scale.Manual { + if manualProfile != nil && len(manualProfile.Sizes) == 1 && + strings.EqualFold(to.String(manualProfile.Sizes[0]), vmSku) { + klog.V(5).Infof("Found matching manual profile for VM SKU: %s, updating count to: %d", vmSku, count) + manualProfile.Count = to.Int32Ptr(count) + requestBody.Properties.VirtualMachinesProfile = agentpool.Properties.VirtualMachinesProfile + break + } + } + return requestBody } -// ForceDeleteNodes deletes nodes from the group regardless of constraints. -func (agentPool *VMsPool) ForceDeleteNodes(nodes []*apiv1.Node) error { - return cloudprovider.ErrNotImplemented +// DeleteNodes removes the specified nodes from the VMPool by extracting their providerIDs +// and performing the appropriate delete or deallocate operation based on the agent pool's +// scale-down policy. This method waits for the asynchronous delete operation to complete, +// with a client-side timeout. +func (vmPool *VMPool) DeleteNodes(nodes []*apiv1.Node) error { + // Ensure we don't scale below the minimum size by excluding VMs in the "Deleting" state. 
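+ // VMs already in the "Deleting" state are excluded from currentSize, so only the remaining VMs are compared against MinSize below.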
+ currentSize, err := vmPool.getCurSize(skipOption{skipDeleting: true, skipFailed: false}) + if err != nil { + return fmt.Errorf("unable to retrieve current size: %w", err) + } + + if int(currentSize) <= vmPool.MinSize() { + return fmt.Errorf("cannot delete nodes as minimum size of %d has been reached", vmPool.MinSize()) + } + + providerIDs, err := vmPool.getProviderIDsForNodes(nodes) + if err != nil { + return fmt.Errorf("failed to retrieve provider IDs for nodes: %w", err) + } + + if len(providerIDs) == 0 { + return nil + } + + klog.V(3).Infof("Deleting nodes from vmPool %s: %v", vmPool.Name, providerIDs) + + machineNames := make([]*string, len(providerIDs)) + for i, providerID := range providerIDs { + // extract the machine name from the providerID by splitting the providerID by '/' and get the last element + // The providerID look like this: + // "azure:///subscriptions/0000000-0000-0000-0000-00000000000/resourceGroups/mc_myrg_mycluster_eastus/providers/Microsoft.Compute/virtualMachines/aks-mypool-12345678-vms0" + machineName, err := resourceName(providerID) + if err != nil { + return err + } + machineNames[i] = &machineName + } + + requestBody := armcontainerservice.AgentPoolDeleteMachinesParameter{ + MachineNames: machineNames, + } + + deleteCtx, cancel := getContextWithTimeout(vmsAsyncContextTimeout) + defer cancel() + defer vmPool.manager.invalidateCache() + + poller, err := vmPool.manager.azClient.agentPoolClient.BeginDeleteMachines( + deleteCtx, + vmPool.manager.config.ClusterResourceGroup, + vmPool.manager.config.ClusterName, + vmPool.agentPoolName, + requestBody, nil) + if err != nil { + klog.Errorf("Failed to delete nodes from agentpool %s in cluster %s with error: %v", + vmPool.agentPoolName, vmPool.manager.config.ClusterName, err) + return err + } + + if _, err := poller.PollUntilDone(deleteCtx, nil); err != nil { + klog.Errorf("agentPoolClient.BeginDeleteMachines for aks cluster %s for scaling down vmPool %s failed with error %s", + vmPool.manager.config.ClusterName, vmPool.agentPoolName, err) + return err + } + klog.Infof("Successfully deleted %d nodes from vmPool %s", len(providerIDs), vmPool.Name) + return nil +} + +func (vmPool *VMPool) getProviderIDsForNodes(nodes []*apiv1.Node) ([]string, error) { + var providerIDs []string + for _, node := range nodes { + belongs, err := vmPool.Belongs(node) + if err != nil { + return nil, fmt.Errorf("failed to check if node %s belongs to vmPool %s: %w", node.Name, vmPool.Name, err) + } + if !belongs { + return nil, fmt.Errorf("node %s does not belong to vmPool %s", node.Name, vmPool.Name) + } + providerIDs = append(providerIDs, node.Spec.ProviderID) + } + return providerIDs, nil +} + +// Belongs returns true if the given k8s node belongs to this vms nodepool. +func (vmPool *VMPool) Belongs(node *apiv1.Node) (bool, error) { + klog.V(6).Infof("Check if node belongs to this vmPool:%s, node:%v\n", vmPool, node) + + ref := &azureRef{ + Name: node.Spec.ProviderID, + } + + nodeGroup, err := vmPool.manager.GetNodeGroupForInstance(ref) + if err != nil { + return false, err + } + if nodeGroup == nil { + return false, fmt.Errorf("%s doesn't belong to a known node group", node.Name) + } + if !strings.EqualFold(nodeGroup.Id(), vmPool.Id()) { + return false, nil + } + return true, nil } // DecreaseTargetSize decreases the target size of the node group. 
-func (agentPool *VMsPool) DecreaseTargetSize(delta int) error { - // TODO(wenxuan): Implement this method - return cloudprovider.ErrNotImplemented +func (vmPool *VMPool) DecreaseTargetSize(delta int) error { + // The TargetSize of a VMPool is automatically adjusted after node deletions. + // This method is invoked in scenarios such as (see details in clusterstate.go): + // - len(readiness.Registered) > acceptableRange.CurrentTarget + // - len(readiness.Registered) < acceptableRange.CurrentTarget - unregisteredNodes + + // For VMPool, this method should not be called because: + // CurrentTarget = len(readiness.Registered) + unregisteredNodes - len(nodesInDeletingState) + // Here, nodesInDeletingState is a subset of unregisteredNodes, + // ensuring len(readiness.Registered) is always within the acceptable range. + + // here we just invalidate the cache to avoid any potential bugs + vmPool.manager.invalidateCache() + klog.Warningf("DecreaseTargetSize called for VMPool %s, but it should not be used, invalidating cache", vmPool.Name) + return nil } -// Id returns the name of the agentPool -func (agentPool *VMsPool) Id() string { - return agentPool.azureRef.Name +// Id returns the name of the agentPool, it is in the format of <agentpoolname>/<sku> +// e.g. mypool1/Standard_D2s_v3 +func (vmPool *VMPool) Id() string { + return vmPool.azureRef.Name } // Debug returns a string with basic details of the agentPool -func (agentPool *VMsPool) Debug() string { - return fmt.Sprintf("%s (%d:%d)", agentPool.Id(), agentPool.MinSize(), agentPool.MaxSize()) +func (vmPool *VMPool) Debug() string { + return fmt.Sprintf("%s (%d:%d)", vmPool.Id(), vmPool.MinSize(), vmPool.MaxSize()) +} + +func isSpotAgentPool(ap armcontainerservice.AgentPool) bool { + if ap.Properties != nil && ap.Properties.ScaleSetPriority != nil { + return strings.EqualFold(string(*ap.Properties.ScaleSetPriority), "Spot") + } + return false +} + +// skipOption is used to determine whether to skip VMs in certain states when calculating the current size of the vmPool. +type skipOption struct { + // skipDeleting indicates whether to skip VMs in the "Deleting" state. + skipDeleting bool + // skipFailed indicates whether to skip VMs in the "Failed" state. + skipFailed bool +} + +// getCurSize determines the current count of VMs in the vmPool, including unregistered ones. +// The source of truth depends on the pool type (spot or non-spot). +func (vmPool *VMPool) getCurSize(op skipOption) (int32, error) { + agentPool, err := vmPool.getAgentpoolFromCache() + if err != nil { + klog.Errorf("Failed to retrieve agent pool %s from cache: %v", vmPool.agentPoolName, err) + return -1, err + } + + // spot pool size is retrieved directly from Azure instead of the cache + if isSpotAgentPool(agentPool) { + return vmPool.getSpotPoolSize() + } + + // non-spot pool size is retrieved from the cache + vms, err := vmPool.getVMsFromCache(op) + if err != nil { + klog.Errorf("Failed to get VMs from cache for agentpool %s with error: %v", vmPool.agentPoolName, err) + return -1, err + } + return int32(len(vms)), nil } -func (agentPool *VMsPool) getVMsFromCache() ([]compute.VirtualMachine, error) { - // vmsPoolMap is a map of agent pool name to the list of virtual machines - vmsPoolMap := agentPool.manager.azureCache.getVirtualMachines() - if _, ok := vmsPoolMap[agentPool.Name]; !ok { - return []compute.VirtualMachine{}, fmt.Errorf("vms pool %s not found in the cache", agentPool.Name) +// getSpotPoolSize retrieves the current size of a spot agent pool directly from Azure. 
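+// For spot pools the cached VM list is not used as the source of truth; the count is read from the
+// agent pool's VirtualMachineNodesStatus for the matching SKU instead (see getCurSize).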
+func (vmPool *VMPool) getSpotPoolSize() (int32, error) { + ap, err := vmPool.getAgentpoolFromAzure() + if err != nil { + klog.Errorf("Failed to get agentpool %s from Azure with error: %v", vmPool.agentPoolName, err) + return -1, err + } + + if ap.Properties != nil { + // the VirtualMachineNodesStatus returned by AKS-RP is constructed from the vm list returned from CRP. + // it only contains VMs in the running state. + for _, status := range ap.Properties.VirtualMachineNodesStatus { + if status != nil { + if strings.EqualFold(to.String(status.Size), vmPool.sku) { + return to.Int32(status.Count), nil + } + } + } } + return -1, fmt.Errorf("failed to get the size of spot agentpool %s", vmPool.agentPoolName) +} + +// getVMsFromCache retrieves the list of virtual machines in this VMPool. +// If excludeDeleting is true, it skips VMs in the "Deleting" state. +// https://learn.microsoft.com/en-us/azure/virtual-machines/states-billing#provisioning-states +func (vmPool *VMPool) getVMsFromCache(op skipOption) ([]compute.VirtualMachine, error) { + vmsMap := vmPool.manager.azureCache.getVirtualMachines() + var filteredVMs []compute.VirtualMachine + + for _, vm := range vmsMap[vmPool.agentPoolName] { + if vm.VirtualMachineProperties == nil || + vm.VirtualMachineProperties.HardwareProfile == nil || + !strings.EqualFold(string(vm.HardwareProfile.VMSize), vmPool.sku) { + continue + } + + if op.skipDeleting && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Deleting") { + klog.V(4).Infof("Skipping VM %s in deleting state", to.String(vm.ID)) + continue + } + + if op.skipFailed && strings.Contains(to.String(vm.VirtualMachineProperties.ProvisioningState), "Failed") { + klog.V(4).Infof("Skipping VM %s in failed state", to.String(vm.ID)) + continue + } - return vmsPoolMap[agentPool.Name], nil + filteredVMs = append(filteredVMs, vm) + } + return filteredVMs, nil } // Nodes returns the list of nodes in the vms agentPool. -func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) { - vms, err := agentPool.getVMsFromCache() +func (vmPool *VMPool) Nodes() ([]cloudprovider.Instance, error) { + vms, err := vmPool.getVMsFromCache(skipOption{}) // no skip option, get all VMs if err != nil { return nil, err } @@ -163,7 +447,7 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) { if vm.ID == nil || len(*vm.ID) == 0 { continue } - resourceID, err := convertResourceGroupNameToLower("azure://" + *vm.ID) + resourceID, err := convertResourceGroupNameToLower("azure://" + to.String(vm.ID)) if err != nil { return nil, err } @@ -173,12 +457,54 @@ func (agentPool *VMsPool) Nodes() ([]cloudprovider.Instance, error) { return nodes, nil } -// TemplateNodeInfo is not implemented. -func (agentPool *VMsPool) TemplateNodeInfo() (*framework.NodeInfo, error) { - return nil, cloudprovider.ErrNotImplemented +// TemplateNodeInfo returns a NodeInfo object that can be used to create a new node in the vmPool. 
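+// The template is built from the cached agent pool definition together with this VMPool's SKU and the
+// configured location, and is used by the autoscaler to simulate nodes when scaling the pool up from zero.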
+func (vmPool *VMPool) TemplateNodeInfo() (*framework.NodeInfo, error) { + ap, err := vmPool.getAgentpoolFromCache() + if err != nil { + return nil, err + } + + inputLabels := map[string]string{} + inputTaints := "" + template, err := buildNodeTemplateFromVMPool(ap, vmPool.manager.config.Location, vmPool.sku, inputLabels, inputTaints) + if err != nil { + return nil, err + } + node, err := buildNodeFromTemplate(vmPool.agentPoolName, template, vmPool.manager, vmPool.manager.config.EnableDynamicInstanceList) + if err != nil { + return nil, err + } + + nodeInfo := framework.NewNodeInfo(node, nil, &framework.PodInfo{Pod: cloudprovider.BuildKubeProxy(vmPool.agentPoolName)}) + + return nodeInfo, nil +} + +func (vmPool *VMPool) getAgentpoolFromCache() (armcontainerservice.AgentPool, error) { + vmsPoolMap := vmPool.manager.azureCache.getVMsPoolMap() + if _, exists := vmsPoolMap[vmPool.agentPoolName]; !exists { + return armcontainerservice.AgentPool{}, fmt.Errorf("VMs agent pool %s not found in cache", vmPool.agentPoolName) + } + return vmsPoolMap[vmPool.agentPoolName], nil +} + +// getAgentpoolFromAzure returns the AKS agentpool from Azure +func (vmPool *VMPool) getAgentpoolFromAzure() (armcontainerservice.AgentPool, error) { + ctx, cancel := getContextWithTimeout(vmsContextTimeout) + defer cancel() + resp, err := vmPool.manager.azClient.agentPoolClient.Get( + ctx, + vmPool.manager.config.ClusterResourceGroup, + vmPool.manager.config.ClusterName, + vmPool.agentPoolName, nil) + if err != nil { + return resp.AgentPool, fmt.Errorf("failed to get agentpool %s in cluster %s with error: %v", + vmPool.agentPoolName, vmPool.manager.config.ClusterName, err) + } + return resp.AgentPool, nil } // AtomicIncreaseSize is not implemented. -func (agentPool *VMsPool) AtomicIncreaseSize(delta int) error { +func (vmPool *VMPool) AtomicIncreaseSize(delta int) error { return cloudprovider.ErrNotImplemented } diff --git a/cluster-autoscaler/cloudprovider/azure/azure_vms_pool_test.go b/cluster-autoscaler/cloudprovider/azure/azure_vms_pool_test.go index c1efbcac843bb0e38836614423f6db6dd9a90244..dd04875fce8f165ebafc2643f4cffd6f4cbd8ecc 100644 --- a/cluster-autoscaler/cloudprovider/azure/azure_vms_pool_test.go +++ b/cluster-autoscaler/cloudprovider/azure/azure_vms_pool_test.go @@ -17,45 +17,64 @@ limitations under the License. 
package azure import ( + "context" "fmt" + "net/http" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" "github.com/Azure/go-autorest/autorest/to" + "go.uber.org/mock/gomock" + "github.com/stretchr/testify/assert" apiv1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/autoscaler/cluster-autoscaler/cloudprovider" "k8s.io/autoscaler/cluster-autoscaler/config" + "k8s.io/autoscaler/cluster-autoscaler/config/dynamic" - providerazure "sigs.k8s.io/cloud-provider-azure/pkg/provider" + "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" ) -func newTestVMsPool(manager *AzureManager, name string) *VMsPool { - return &VMsPool{ +const ( + vmSku = "Standard_D2_v2" + vmsAgentPoolName = "test-vms-pool" + vmsNodeGroupName = vmsAgentPoolName + "/" + vmSku + fakeVMsNodeName = "aks-" + vmsAgentPoolName + "-13222729-vms%d" + fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/" + fakeVMsNodeName +) + +func newTestVMsPool(manager *AzureManager) *VMPool { + return &VMPool{ azureRef: azureRef{ - Name: name, + Name: vmsNodeGroupName, }, - manager: manager, - minSize: 3, - maxSize: 10, + manager: manager, + minSize: 3, + maxSize: 10, + agentPoolName: vmsAgentPoolName, + sku: vmSku, } } -const ( - fakeVMsPoolVMID = "/subscriptions/test-subscription-id/resourceGroups/test-rg/providers/Microsoft.Compute/virtualMachines/%d" -) - func newTestVMsPoolVMList(count int) []compute.VirtualMachine { var vmList []compute.VirtualMachine + for i := 0; i < count; i++ { vm := compute.VirtualMachine{ ID: to.StringPtr(fmt.Sprintf(fakeVMsPoolVMID, i)), VirtualMachineProperties: &compute.VirtualMachineProperties{ VMID: to.StringPtr(fmt.Sprintf("123E4567-E89B-12D3-A456-426655440000-%d", i)), + HardwareProfile: &compute.HardwareProfile{ + VMSize: compute.VirtualMachineSizeTypes(vmSku), + }, + ProvisioningState: to.StringPtr("Succeeded"), }, Tags: map[string]*string{ agentpoolTypeTag: to.StringPtr("VirtualMachines"), - agentpoolNameTag: to.StringPtr("test-vms-pool"), + agentpoolNameTag: to.StringPtr(vmsAgentPoolName), }, } vmList = append(vmList, vm) @@ -63,41 +82,73 @@ func newTestVMsPoolVMList(count int) []compute.VirtualMachine { return vmList } -func newVMsNode(vmID int64) *apiv1.Node { - node := &apiv1.Node{ +func newVMsNode(vmIdx int64) *apiv1.Node { + return &apiv1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf(fakeVMsNodeName, vmIdx), + }, Spec: apiv1.NodeSpec{ - ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmID), + ProviderID: "azure://" + fmt.Sprintf(fakeVMsPoolVMID, vmIdx), }, } - return node } -func TestNewVMsPool(t *testing.T) { - spec := &dynamic.NodeGroupSpec{ - Name: "test-nodepool", - MinSize: 1, - MaxSize: 5, +func getTestVMsAgentPool(isSystemPool bool) armcontainerservice.AgentPool { + mode := armcontainerservice.AgentPoolModeUser + if isSystemPool { + mode = armcontainerservice.AgentPoolModeSystem } - am := &AzureManager{ - config: &Config{ - Config: providerazure.Config{ - ResourceGroup: "test-resource-group", + vmsPoolType := armcontainerservice.AgentPoolTypeVirtualMachines + return armcontainerservice.AgentPool{ + Name: to.StringPtr(vmsAgentPoolName), + Properties: &armcontainerservice.ManagedClusterAgentPoolProfileProperties{ + Type: &vmsPoolType, + Mode: &mode, + VirtualMachinesProfile: 
&armcontainerservice.VirtualMachinesProfile{ + Scale: &armcontainerservice.ScaleProfile{ + Manual: []*armcontainerservice.ManualScaleProfile{ + { + Count: to.Int32Ptr(3), + Sizes: []*string{to.StringPtr(vmSku)}, + }, + }, + }, + }, + VirtualMachineNodesStatus: []*armcontainerservice.VirtualMachineNodes{ + { + Count: to.Int32Ptr(3), + Size: to.StringPtr(vmSku), + }, }, }, } +} + +func TestNewVMsPool(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + manager := newTestAzureManager(t) + manager.azClient.agentPoolClient = mockAgentpoolclient + manager.config.ResourceGroup = "MC_rg" + manager.config.ClusterResourceGroup = "rg" + manager.config.ClusterName = "mycluster" - nodepool := NewVMsPool(spec, am) + spec := &dynamic.NodeGroupSpec{ + Name: vmsAgentPoolName, + MinSize: 1, + MaxSize: 10, + } - assert.Equal(t, "test-nodepool", nodepool.azureRef.Name) - assert.Equal(t, "test-resource-group", nodepool.resourceGroup) - assert.Equal(t, int64(-1), nodepool.curSize) - assert.Equal(t, 1, nodepool.minSize) - assert.Equal(t, 5, nodepool.maxSize) - assert.Equal(t, am, nodepool.manager) + ap, err := NewVMPool(spec, manager, vmsAgentPoolName, vmSku) + assert.NoError(t, err) + assert.Equal(t, vmsAgentPoolName, ap.azureRef.Name) + assert.Equal(t, 1, ap.minSize) + assert.Equal(t, 10, ap.maxSize) } func TestMinSize(t *testing.T) { - agentPool := &VMsPool{ + agentPool := &VMPool{ minSize: 1, } @@ -105,12 +156,12 @@ func TestMinSize(t *testing.T) { } func TestExist(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} assert.True(t, agentPool.Exist()) } func TestCreate(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} nodeGroup, err := agentPool.Create() assert.Nil(t, nodeGroup) @@ -118,65 +169,43 @@ func TestCreate(t *testing.T) { } func TestDelete(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} err := agentPool.Delete() assert.Equal(t, cloudprovider.ErrNotImplemented, err) } func TestAutoprovisioned(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} assert.False(t, agentPool.Autoprovisioned()) } func TestGetOptions(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} defaults := config.NodeGroupAutoscalingOptions{} options, err := agentPool.GetOptions(defaults) assert.Nil(t, options) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) + assert.Nil(t, err) } func TestMaxSize(t *testing.T) { - agentPool := &VMsPool{ + agentPool := &VMPool{ maxSize: 10, } assert.Equal(t, 10, agentPool.MaxSize()) } -func TestTargetSize(t *testing.T) { - agentPool := &VMsPool{} - - size, err := agentPool.TargetSize() - assert.Equal(t, -1, size) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) -} - -func TestIncreaseSize(t *testing.T) { - agentPool := &VMsPool{} - - err := agentPool.IncreaseSize(1) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) -} - -func TestDeleteNodes(t *testing.T) { - agentPool := &VMsPool{} - - err := agentPool.DeleteNodes(nil) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) -} - func TestDecreaseTargetSize(t *testing.T) { - agentPool := &VMsPool{} + agentPool := newTestVMsPool(newTestAzureManager(t)) err := agentPool.DecreaseTargetSize(1) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) + assert.Nil(t, err) } func TestId(t *testing.T) { - agentPool := &VMsPool{ + agentPool := &VMPool{ azureRef: azureRef{ Name: "test-id", }, @@ -186,7 +215,7 @@ func TestId(t *testing.T) { } func TestDebug(t *testing.T) { - agentPool := 
&VMsPool{ + agentPool := &VMPool{ azureRef: azureRef{ Name: "test-debug", }, @@ -198,115 +227,341 @@ func TestDebug(t *testing.T) { assert.Equal(t, expectedDebugString, agentPool.Debug()) } func TestTemplateNodeInfo(t *testing.T) { - agentPool := &VMsPool{} + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ap := newTestVMsPool(newTestAzureManager(t)) + ap.manager.config.EnableVMsAgentPool = true + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ap.manager.azClient.agentPoolClient = mockAgentpoolclient + agentpool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&agentpool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). + Return(fakeAPListPager) + + ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config) + assert.NoError(t, err) + ap.manager.azureCache = ac - nodeInfo, err := agentPool.TemplateNodeInfo() - assert.Nil(t, nodeInfo) - assert.Equal(t, cloudprovider.ErrNotImplemented, err) + nodeInfo, err := ap.TemplateNodeInfo() + assert.NotNil(t, nodeInfo) + assert.Nil(t, err) } func TestAtomicIncreaseSize(t *testing.T) { - agentPool := &VMsPool{} + agentPool := &VMPool{} err := agentPool.AtomicIncreaseSize(1) assert.Equal(t, cloudprovider.ErrNotImplemented, err) } -// Test cases for getVMsFromCache() -// Test case 1 - when the vms pool is not found in the cache -// Test case 2 - when the vms pool is found in the cache but has no VMs -// Test case 3 - when the vms pool is found in the cache and has VMs -// Test case 4 - when the vms pool is found in the cache and has VMs with no name func TestGetVMsFromCache(t *testing.T) { - // Test case 1 manager := &AzureManager{ azureCache: &azureCache{ virtualMachines: make(map[string][]compute.VirtualMachine), + vmsPoolMap: make(map[string]armcontainerservice.AgentPool), }, } - agentPool := &VMsPool{ - manager: manager, - azureRef: azureRef{ - Name: "test-vms-pool", - }, + agentPool := &VMPool{ + manager: manager, + agentPoolName: vmsAgentPoolName, + sku: vmSku, } - _, err := agentPool.getVMsFromCache() - assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache") + // Test case 1 - when the vms pool is not found in the cache + vms, err := agentPool.getVMsFromCache(skipOption{}) + assert.Nil(t, err) + assert.Len(t, vms, 0) - // Test case 2 - manager.azureCache.virtualMachines["test-vms-pool"] = []compute.VirtualMachine{} - _, err = agentPool.getVMsFromCache() + // Test case 2 - when the vms pool is found in the cache but has no VMs + manager.azureCache.virtualMachines[vmsAgentPoolName] = []compute.VirtualMachine{} + vms, err = agentPool.getVMsFromCache(skipOption{}) assert.NoError(t, err) + assert.Len(t, vms, 0) - // Test case 3 - manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) - vms, err := agentPool.getVMsFromCache() + // Test case 3 - when the vms pool is found in the cache and has VMs + manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3) + vms, err = agentPool.getVMsFromCache(skipOption{}) assert.NoError(t, err) assert.Len(t, vms, 3) - // Test case 4 - manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3) - agentPool.azureRef.Name = "" - _, err = agentPool.getVMsFromCache() - assert.EqualError(t, err, "vms pool not found in the cache") + // Test case 4 - should skip failed VMs + vmList := newTestVMsPoolVMList(3) + vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Failed") + manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList + 
vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true}) + assert.NoError(t, err) + assert.Len(t, vms, 2) + + // Test case 5 - should skip deleting VMs + vmList = newTestVMsPoolVMList(3) + vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting") + manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList + vms, err = agentPool.getVMsFromCache(skipOption{skipDeleting: true}) + assert.NoError(t, err) + assert.Len(t, vms, 2) + + // Test case 6 - should not skip deleting VMs + vmList = newTestVMsPoolVMList(3) + vmList[0].VirtualMachineProperties.ProvisioningState = to.StringPtr("Deleting") + manager.azureCache.virtualMachines[vmsAgentPoolName] = vmList + vms, err = agentPool.getVMsFromCache(skipOption{skipFailed: true}) + assert.NoError(t, err) + assert.Len(t, vms, 3) + + // Test case 7 - when the vms pool is found in the cache and has VMs with no name + manager.azureCache.virtualMachines[vmsAgentPoolName] = newTestVMsPoolVMList(3) + agentPool.agentPoolName = "" + vms, err = agentPool.getVMsFromCache(skipOption{}) + assert.NoError(t, err) + assert.Len(t, vms, 0) +} + +func TestGetVMsFromCacheForVMsPool(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ap := newTestVMsPool(newTestAzureManager(t)) + + expectedVMs := newTestVMsPoolVMList(2) + mockVMClient := mockvmclient.NewMockInterface(ctrl) + ap.manager.azClient.virtualMachinesClient = mockVMClient + ap.manager.config.EnableVMsAgentPool = true + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ap.manager.azClient.agentPoolClient = mockAgentpoolclient + mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil) + + agentpool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&agentpool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). + Return(fakeAPListPager) + + ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config) + assert.NoError(t, err) + ac.enableVMsAgentPool = true + ap.manager.azureCache = ac + + vms, err := ap.getVMsFromCache(skipOption{}) + assert.Equal(t, 2, len(vms)) + assert.NoError(t, err) } -// Test cases for Nodes() -// Test case 1 - when there are no VMs in the pool -// Test case 2 - when there are VMs in the pool -// Test case 3 - when there are VMs in the pool with no ID -// Test case 4 - when there is an error converting resource group name -// Test case 5 - when there is an error getting VMs from cache func TestNodes(t *testing.T) { - // Test case 1 - manager := &AzureManager{ - azureCache: &azureCache{ - virtualMachines: make(map[string][]compute.VirtualMachine), + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ap := newTestVMsPool(newTestAzureManager(t)) + expectedVMs := newTestVMsPoolVMList(2) + + mockVMClient := mockvmclient.NewMockInterface(ctrl) + ap.manager.azClient.virtualMachinesClient = mockVMClient + mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil) + + ap.manager.config.EnableVMsAgentPool = true + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ap.manager.azClient.agentPoolClient = mockAgentpoolclient + agentpool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&agentpool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). 
+ Return(fakeAPListPager) + + ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config) + assert.NoError(t, err) + ap.manager.azureCache = ac + + vms, err := ap.Nodes() + assert.Equal(t, 2, len(vms)) + assert.NoError(t, err) +} + +func TestGetCurSizeForVMsPool(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ap := newTestVMsPool(newTestAzureManager(t)) + expectedVMs := newTestVMsPoolVMList(3) + + mockVMClient := mockvmclient.NewMockInterface(ctrl) + ap.manager.azClient.virtualMachinesClient = mockVMClient + mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil) + + ap.manager.config.EnableVMsAgentPool = true + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ap.manager.azClient.agentPoolClient = mockAgentpoolclient + agentpool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&agentpool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). + Return(fakeAPListPager) + + ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config) + assert.NoError(t, err) + ap.manager.azureCache = ac + + curSize, err := ap.getCurSize(skipOption{}) + assert.NoError(t, err) + assert.Equal(t, int32(3), curSize) +} + +func TestVMsPoolIncreaseSize(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + manager := newTestAzureManager(t) + + ap := newTestVMsPool(manager) + expectedVMs := newTestVMsPoolVMList(3) + + mockVMClient := mockvmclient.NewMockInterface(ctrl) + ap.manager.azClient.virtualMachinesClient = mockVMClient + mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil) + + ap.manager.config.EnableVMsAgentPool = true + mockAgentpoolclient := NewMockAgentPoolsClient(ctrl) + ap.manager.azClient.agentPoolClient = mockAgentpoolclient + agentpool := getTestVMsAgentPool(false) + fakeAPListPager := getFakeAgentpoolListPager(&agentpool) + mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil). 
+		Return(fakeAPListPager)
+
+	ac, err := newAzureCache(ap.manager.azClient, refreshInterval, *ap.manager.config)
+	assert.NoError(t, err)
+	ap.manager.azureCache = ac
+
+	// failure case 1
+	err1 := ap.IncreaseSize(-1)
+	expectedErr := fmt.Errorf("size increase must be positive, current delta: -1")
+	assert.Equal(t, expectedErr, err1)
+
+	// failure case 2
+	err2 := ap.IncreaseSize(8)
+	expectedErr = fmt.Errorf("size-increasing request of 11 is bigger than max size 10")
+	assert.Equal(t, expectedErr, err2)
+
+	// success case 3
+	resp := &http.Response{
+		Header: map[string][]string{
+			"Fake-Poller-Status": {"Done"},
 		},
 	}
-	agentPool := &VMsPool{
-		manager: manager,
-		azureRef: azureRef{
-			Name: "test-vms-pool",
+
+	fakePoller, pollerErr := runtime.NewPoller(resp, runtime.Pipeline{},
+		&runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{
+			Handler: &fakehandler[armcontainerservice.AgentPoolsClientCreateOrUpdateResponse]{},
+		})
+
+	assert.NoError(t, pollerErr)
+
+	mockAgentpoolclient.EXPECT().BeginCreateOrUpdate(
+		gomock.Any(), manager.config.ClusterResourceGroup,
+		manager.config.ClusterName,
+		vmsAgentPoolName,
+		gomock.Any(), gomock.Any()).Return(fakePoller, nil)
+
+	err3 := ap.IncreaseSize(1)
+	assert.NoError(t, err3)
+}
+
+func TestDeleteVMsPoolNodes_Failed(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+	ap := newTestVMsPool(newTestAzureManager(t))
+	node := newVMsNode(0)
+
+	expectedVMs := newTestVMsPoolVMList(3)
+	mockVMClient := mockvmclient.NewMockInterface(ctrl)
+	ap.manager.azClient.virtualMachinesClient = mockVMClient
+	ap.manager.config.EnableVMsAgentPool = true
+	mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
+	agentpool := getTestVMsAgentPool(false)
+	ap.manager.azClient.agentPoolClient = mockAgentpoolclient
+	fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
+	mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
+	mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
+
+	ap.manager.azureCache.enableVMsAgentPool = true
+	registered := ap.manager.RegisterNodeGroup(ap)
+	assert.True(t, registered)
+
+	ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
+	ap.manager.forceRefresh()
+
+	// failure case
+	deleteErr := ap.DeleteNodes([]*apiv1.Node{node})
+	assert.Error(t, deleteErr)
+	assert.Contains(t, deleteErr.Error(), "cannot delete nodes as minimum size of 3 has been reached")
+}
+
+func TestDeleteVMsPoolNodes_Success(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+	ap := newTestVMsPool(newTestAzureManager(t))
+
+	expectedVMs := newTestVMsPoolVMList(5)
+	mockVMClient := mockvmclient.NewMockInterface(ctrl)
+	ap.manager.azClient.virtualMachinesClient = mockVMClient
+	ap.manager.config.EnableVMsAgentPool = true
+	mockAgentpoolclient := NewMockAgentPoolsClient(ctrl)
+	agentpool := getTestVMsAgentPool(false)
+	ap.manager.azClient.agentPoolClient = mockAgentpoolclient
+	fakeAPListPager := getFakeAgentpoolListPager(&agentpool)
+	mockAgentpoolclient.EXPECT().NewListPager(gomock.Any(), gomock.Any(), nil).Return(fakeAPListPager)
+	mockVMClient.EXPECT().List(gomock.Any(), ap.manager.config.ResourceGroup).Return(expectedVMs, nil)
+
+	ap.manager.azureCache.enableVMsAgentPool = true
+	registered := ap.manager.RegisterNodeGroup(ap)
+	assert.True(t, registered)
+
+	ap.manager.explicitlyConfigured[vmsNodeGroupName] = true
+	ap.manager.forceRefresh()
+
+	// success case
+	resp := &http.Response{
+		Header: map[string][]string{
+			"Fake-Poller-Status": {"Done"},
 		},
 	}
+	fakePoller, err := runtime.NewPoller(resp, runtime.Pipeline{},
+		&runtime.NewPollerOptions[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{
+			Handler: &fakehandler[armcontainerservice.AgentPoolsClientDeleteMachinesResponse]{},
+		})
+	assert.NoError(t, err)
+
+	mockAgentpoolclient.EXPECT().BeginDeleteMachines(
+		gomock.Any(), ap.manager.config.ClusterResourceGroup,
+		ap.manager.config.ClusterName,
+		vmsAgentPoolName,
+		gomock.Any(), gomock.Any()).Return(fakePoller, nil)
+	node := newVMsNode(0)
+	derr := ap.DeleteNodes([]*apiv1.Node{node})
+	assert.NoError(t, derr)
+}
-	nodes, err := agentPool.Nodes()
-	assert.EqualError(t, err, "vms pool test-vms-pool not found in the cache")
-	assert.Empty(t, nodes)
+type fakehandler[T any] struct{}
-	// Test case 2
-	manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
-	nodes, err = agentPool.Nodes()
-	assert.NoError(t, err)
-	assert.Len(t, nodes, 3)
+func (f *fakehandler[T]) Done() bool {
+	return true
+}
-	// Test case 3
-	manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
-	manager.azureCache.virtualMachines["test-vms-pool"][0].ID = nil
-	nodes, err = agentPool.Nodes()
-	assert.NoError(t, err)
-	assert.Len(t, nodes, 2)
-	manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
-	emptyString := ""
-	manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &emptyString
-	nodes, err = agentPool.Nodes()
-	assert.NoError(t, err)
-	assert.Len(t, nodes, 2)
-
-	// Test case 4
-	manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(3)
-	bogusID := "foo"
-	manager.azureCache.virtualMachines["test-vms-pool"][0].ID = &bogusID
-	nodes, err = agentPool.Nodes()
-	assert.Empty(t, nodes)
-	assert.Error(t, err)
-
-	// Test case 5
-	manager.azureCache.virtualMachines["test-vms-pool"] = newTestVMsPoolVMList(1)
-	agentPool.azureRef.Name = ""
-	nodes, err = agentPool.Nodes()
-	assert.Empty(t, nodes)
-	assert.Error(t, err)
+func (f *fakehandler[T]) Poll(ctx context.Context) (*http.Response, error) {
+	return nil, nil
+}
+
+func (f *fakehandler[T]) Result(ctx context.Context, out *T) error {
+	return nil
+}
+
+func getFakeAgentpoolListPager(agentpool ...*armcontainerservice.AgentPool) *runtime.Pager[armcontainerservice.AgentPoolsClientListResponse] {
+	fakeFetcher := func(ctx context.Context, response *armcontainerservice.AgentPoolsClientListResponse) (armcontainerservice.AgentPoolsClientListResponse, error) {
+		return armcontainerservice.AgentPoolsClientListResponse{
+			AgentPoolListResult: armcontainerservice.AgentPoolListResult{
+				Value: agentpool,
+			},
+		}, nil
+	}
+
+	return runtime.NewPager(runtime.PagingHandler[armcontainerservice.AgentPoolsClientListResponse]{
+		More: func(response armcontainerservice.AgentPoolsClientListResponse) bool {
+			return false
+		},
+		Fetcher: fakeFetcher,
+	})
 }
diff --git a/cluster-autoscaler/go.mod b/cluster-autoscaler/go.mod
index 3c93e7d2b766af9c4c969bd3d5b374a198005a2d..93060dea47fe08fbed69e8470b8f312ae1d09eca 100644
--- a/cluster-autoscaler/go.mod
+++ b/cluster-autoscaler/go.mod
@@ -8,9 +8,9 @@ require (
 	cloud.google.com/go/compute/metadata v0.3.0
 	github.com/Azure/azure-sdk-for-go v68.0.0+incompatible
 	github.com/Azure/azure-sdk-for-go-extensions v0.1.6
-	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1
-	github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2
-	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1
+	github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0
+	github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0
+	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2
 	github.com/Azure/go-autorest/autorest v0.11.29
 	github.com/Azure/go-autorest/autorest/adal v0.9.24
 	github.com/Azure/go-autorest/autorest/azure/auth v0.5.13
@@ -64,11 +64,12 @@ require (
 	cel.dev/expr v0.18.0 // indirect
-	github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 // indirect
+	github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v5 v5.6.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0 // indirect
+	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v4 v4.3.0 // indirect
 	github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.2.0 // indirect
diff --git a/cluster-autoscaler/go.sum b/cluster-autoscaler/go.sum
index b6d59275d5c852be29504f22250d12847ea17828..ee6158e42a1b3e2960de7cd41b908ed5d3bb45a5 100644
--- a/cluster-autoscaler/go.sum
+++ b/cluster-autoscaler/go.sum
@@ -9,12 +9,12 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0
 github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
 github.com/Azure/azure-sdk-for-go-extensions v0.1.6 h1:EXGvDcj54u98XfaI/Cy65Ds6vNsIJeGKYf0eNLB1y4Q=
 github.com/Azure/azure-sdk-for-go-extensions v0.1.6/go.mod h1:27StPiXJp6Xzkq2AQL7gPK7VC0hgmCnUKlco1dO1jaM=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
-github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 h1:FDif4R1+UUR+00q6wquyX90K7A8dN+R5E8GEadoP7sU=
-github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2/go.mod h1:aiYBYui4BJ/BJCAIKs92XiPyQfTaBWqvHujDwKb6CBU=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ=
-github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0 h1:GJHeeA2N7xrG3q30L2UXDyuWRzDM900/65j70wcM4Ww=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.13.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY=
+github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY=
 github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0 h1:xnO4sFyG8UH2fElBkcqLTOZsAajvKfnSlgBBW8dXYjw=
 github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.12.0/go.mod h1:XD3DIOOVgBCO03OleB1fHjgktVRFxlT++KwKgIOewdM=
 github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.1 h1:FbH3BbSb4bvGluTesZZ+ttN/MDsnMmQP36OSnDuSXqw=
@@ -25,10 +25,14 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armconta
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.2.0/go.mod h1:E7ltexgRDmeJ0fJWv0D/HLwY2xbDdN+uv+X2uZtOx3w=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0 h1:1u/K2BFv0MwkG6he8RYuUcbbeK22rkoZbg4lKa/msZU=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v2 v2.4.0/go.mod h1:U5gpsREQZE6SLk1t/cFfc1eMhYAlYpEzvaYXuDfefy8=
-github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1 h1:iqhrjj9w9/AQZsHjaOVyloamkeAFRbWI0iHNy6INMYk=
-github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.9.0-beta.1/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0 h1:0nGmzwBv5ougvzfGPCO2ljFRHvun57KpNrVCMrlk0ns=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4 v4.8.0/go.mod h1:gYq8wyDgv6JLhGbAU6gg8amCPgQWRE+aCvrV2gyzdfs=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2 h1:re+BEe/OafvSyRy2vM+Fyu+EcUK34O2o/Fa6WO3ITZM=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.1.0-beta.2/go.mod h1:5zx285T5OLk+iQbfOuexhhO7J6dfzkqVkFgS/+s7XaA=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw=
+github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0 h1:HlZMUZW8S4P9oob1nCHxCCKrytxyLc+24nUJGssoEto=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.4.0/go.mod h1:StGsLbuJh06Bd8IBfnAlIFV3fLb+gkczONWf15hpX2E=
 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/managementgroups/armmanagementgroups v1.0.0 h1:pPvTJ1dY0sA35JOeFq6TsY2xj6Z85Yo23Pj4wCCvu4o=