Skip to content

Commit 3284a8c

Browse files
authored
Improved usability of databricks auth login ... --configure-cluster flow by displaying cluster type and runtime version (#956)
This PR adds selectors for Databricks-connect compatible clusters and SQL warehouses Tested in #914
1 parent f111b08 commit 3284a8c

File tree

5 files changed

+475
-11
lines changed

5 files changed

+475
-11
lines changed

cmd/auth/login.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ import (
88
"github.com/databricks/cli/libs/auth"
99
"github.com/databricks/cli/libs/cmdio"
1010
"github.com/databricks/cli/libs/databrickscfg"
11+
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
1112
"github.com/databricks/databricks-sdk-go"
1213
"github.com/databricks/databricks-sdk-go/config"
13-
"github.com/databricks/databricks-sdk-go/service/compute"
1414
"github.com/spf13/cobra"
1515
)
1616

@@ -28,6 +28,8 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
2828
return nil
2929
}
3030

31+
// minimalDbConnectVersion is the lowest Databricks Runtime version usable
// with Databricks Connect; older clusters are filtered out of the picker.
const minimalDbConnectVersion = "13.1"
32+
3133
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
3234
cmd := &cobra.Command{
3335
Use: "login [HOST]",
@@ -95,19 +97,12 @@ func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
9597
return err
9698
}
9799
ctx := cmd.Context()
98-
99-
promptSpinner := cmdio.Spinner(ctx)
100-
promptSpinner <- "Loading list of clusters to select from"
101-
names, err := w.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{})
102-
close(promptSpinner)
103-
if err != nil {
104-
return fmt.Errorf("failed to load clusters list. Original error: %w", err)
105-
}
106-
clusterId, err := cmdio.Select(ctx, names, "Choose cluster")
100+
clusterID, err := cfgpickers.AskForCluster(ctx, w,
101+
cfgpickers.WithDatabricksConnect(minimalDbConnectVersion))
107102
if err != nil {
108103
return err
109104
}
110-
cfg.ClusterID = clusterId
105+
cfg.ClusterID = clusterID
111106
}
112107

113108
if profileName != "" {
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
package cfgpickers
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"regexp"
8+
"strings"
9+
10+
"github.com/databricks/cli/libs/cmdio"
11+
"github.com/databricks/databricks-sdk-go"
12+
"github.com/databricks/databricks-sdk-go/service/compute"
13+
"github.com/databricks/databricks-sdk-go/service/iam"
14+
"github.com/fatih/color"
15+
"github.com/manifoldco/promptui"
16+
"golang.org/x/mod/semver"
17+
)
18+
19+
var minUcRuntime = canonicalVersion("v12.0")
20+
21+
var dbrVersionRegex = regexp.MustCompile(`^(\d+\.\d+)\.x-.*`)
22+
var dbrSnapshotVersionRegex = regexp.MustCompile(`^(\d+)\.x-snapshot.*`)
23+
24+
func canonicalVersion(v string) string {
25+
return semver.Canonical("v" + strings.TrimPrefix(v, "v"))
26+
}
27+
28+
func GetRuntimeVersion(cluster compute.ClusterDetails) (string, bool) {
29+
match := dbrVersionRegex.FindStringSubmatch(cluster.SparkVersion)
30+
if len(match) < 1 {
31+
match = dbrSnapshotVersionRegex.FindStringSubmatch(cluster.SparkVersion)
32+
if len(match) > 1 {
33+
// we return 14.999 for 14.x-snapshot for semver.Compare() to work properly
34+
return fmt.Sprintf("%s.999", match[1]), true
35+
}
36+
return "", false
37+
}
38+
return match[1], true
39+
}
40+
41+
func IsCompatibleWithUC(cluster compute.ClusterDetails, minVersion string) bool {
42+
minVersion = canonicalVersion(minVersion)
43+
if semver.Compare(minUcRuntime, minVersion) >= 0 {
44+
return false
45+
}
46+
runtimeVersion, ok := GetRuntimeVersion(cluster)
47+
if !ok {
48+
return false
49+
}
50+
clusterRuntime := canonicalVersion(runtimeVersion)
51+
if semver.Compare(minVersion, clusterRuntime) > 0 {
52+
return false
53+
}
54+
switch cluster.DataSecurityMode {
55+
case compute.DataSecurityModeUserIsolation, compute.DataSecurityModeSingleUser:
56+
return true
57+
default:
58+
return false
59+
}
60+
}
61+
62+
// ErrNoCompatibleClusters is returned by AskForCluster when no cluster in
// the workspace passes all of the supplied filters.
var ErrNoCompatibleClusters = errors.New("no compatible clusters found")
63+
64+
// compatibleCluster decorates cluster details with the human-readable
// runtime version name used when rendering the interactive picker.
type compatibleCluster struct {
	compute.ClusterDetails
	versionName string // display name for SparkVersion, e.g. "14.5 (Awesome)"
}
68+
69+
func (v compatibleCluster) Access() string {
70+
switch v.DataSecurityMode {
71+
case compute.DataSecurityModeUserIsolation:
72+
return "Shared"
73+
case compute.DataSecurityModeSingleUser:
74+
return "Assigned"
75+
default:
76+
return "Unknown"
77+
}
78+
}
79+
80+
func (v compatibleCluster) Runtime() string {
81+
runtime, _, _ := strings.Cut(v.versionName, " (")
82+
return runtime
83+
}
84+
85+
func (v compatibleCluster) State() string {
86+
state := v.ClusterDetails.State
87+
switch state {
88+
case compute.StateRunning, compute.StateResizing:
89+
return color.GreenString(state.String())
90+
case compute.StateError, compute.StateTerminated, compute.StateTerminating, compute.StateUnknown:
91+
return color.RedString(state.String())
92+
default:
93+
return color.BlueString(state.String())
94+
}
95+
}
96+
97+
// clusterFilter decides whether a cluster should be offered to the
// current user in the interactive picker.
type clusterFilter func(cluster *compute.ClusterDetails, me *iam.User) bool
98+
99+
func WithDatabricksConnect(minVersion string) func(*compute.ClusterDetails, *iam.User) bool {
100+
return func(cluster *compute.ClusterDetails, me *iam.User) bool {
101+
if !IsCompatibleWithUC(*cluster, minVersion) {
102+
return false
103+
}
104+
switch cluster.ClusterSource {
105+
case compute.ClusterSourceJob,
106+
compute.ClusterSourceModels,
107+
compute.ClusterSourcePipeline,
108+
compute.ClusterSourcePipelineMaintenance,
109+
compute.ClusterSourceSql:
110+
// only UI and API clusters are usable for DBConnect.
111+
// `CanUseClient: "NOTEBOOKS"`` didn't seem to have an effect.
112+
return false
113+
}
114+
if cluster.SingleUserName != "" && cluster.SingleUserName != me.UserName {
115+
return false
116+
}
117+
return true
118+
}
119+
}
120+
121+
func loadInteractiveClusters(ctx context.Context, w *databricks.WorkspaceClient, filters []clusterFilter) ([]compatibleCluster, error) {
122+
promptSpinner := cmdio.Spinner(ctx)
123+
promptSpinner <- "Loading list of clusters to select from"
124+
defer close(promptSpinner)
125+
all, err := w.Clusters.ListAll(ctx, compute.ListClustersRequest{
126+
CanUseClient: "NOTEBOOKS",
127+
})
128+
if err != nil {
129+
return nil, fmt.Errorf("list clusters: %w", err)
130+
}
131+
me, err := w.CurrentUser.Me(ctx)
132+
if err != nil {
133+
return nil, fmt.Errorf("current user: %w", err)
134+
}
135+
versions := map[string]string{}
136+
sv, err := w.Clusters.SparkVersions(ctx)
137+
if err != nil {
138+
return nil, fmt.Errorf("list runtime versions: %w", err)
139+
}
140+
for _, v := range sv.Versions {
141+
versions[v.Key] = v.Name
142+
}
143+
var compatible []compatibleCluster
144+
for _, cluster := range all {
145+
var skip bool
146+
for _, filter := range filters {
147+
if !filter(&cluster, me) {
148+
skip = true
149+
}
150+
}
151+
if skip {
152+
continue
153+
}
154+
compatible = append(compatible, compatibleCluster{
155+
ClusterDetails: cluster,
156+
versionName: versions[cluster.SparkVersion],
157+
})
158+
}
159+
return compatible, nil
160+
}
161+
162+
func AskForCluster(ctx context.Context, w *databricks.WorkspaceClient, filters ...clusterFilter) (string, error) {
163+
compatible, err := loadInteractiveClusters(ctx, w, filters)
164+
if err != nil {
165+
return "", fmt.Errorf("load: %w", err)
166+
}
167+
if len(compatible) == 0 {
168+
return "", ErrNoCompatibleClusters
169+
}
170+
if len(compatible) == 1 {
171+
return compatible[0].ClusterId, nil
172+
}
173+
i, _, err := cmdio.RunSelect(ctx, &promptui.Select{
174+
Label: "Choose compatible cluster",
175+
Items: compatible,
176+
Searcher: func(input string, idx int) bool {
177+
lower := strings.ToLower(compatible[idx].ClusterName)
178+
return strings.Contains(lower, input)
179+
},
180+
StartInSearchMode: true,
181+
Templates: &promptui.SelectTemplates{
182+
Label: "{{.ClusterName | faint}}",
183+
Active: `{{.ClusterName | bold}} ({{.State}} {{.Access}} Runtime {{.Runtime}}) ({{.ClusterId | faint}})`,
184+
Inactive: `{{.ClusterName}} ({{.State}} {{.Access}} Runtime {{.Runtime}})`,
185+
Selected: `{{ "Configured cluster" | faint }}: {{ .ClusterName | bold }} ({{.ClusterId | faint}})`,
186+
},
187+
})
188+
if err != nil {
189+
return "", err
190+
}
191+
return compatible[i].ClusterId, nil
192+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package cfgpickers
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"testing"
7+
8+
"github.com/databricks/cli/libs/cmdio"
9+
"github.com/databricks/cli/libs/flags"
10+
"github.com/databricks/databricks-sdk-go"
11+
"github.com/databricks/databricks-sdk-go/qa"
12+
"github.com/databricks/databricks-sdk-go/service/compute"
13+
"github.com/databricks/databricks-sdk-go/service/iam"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestIsCompatible(t *testing.T) {
18+
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
19+
SparkVersion: "13.2.x-aarch64-scala2.12",
20+
DataSecurityMode: compute.DataSecurityModeUserIsolation,
21+
}, "13.0"))
22+
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
23+
SparkVersion: "13.2.x-aarch64-scala2.12",
24+
DataSecurityMode: compute.DataSecurityModeNone,
25+
}, "13.0"))
26+
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
27+
SparkVersion: "9.1.x-photon-scala2.12",
28+
DataSecurityMode: compute.DataSecurityModeNone,
29+
}, "13.0"))
30+
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
31+
SparkVersion: "9.1.x-photon-scala2.12",
32+
DataSecurityMode: compute.DataSecurityModeNone,
33+
}, "10.0"))
34+
require.False(t, IsCompatibleWithUC(compute.ClusterDetails{
35+
SparkVersion: "custom-9.1.x-photon-scala2.12",
36+
DataSecurityMode: compute.DataSecurityModeNone,
37+
}, "14.0"))
38+
}
39+
40+
func TestIsCompatibleWithSnapshots(t *testing.T) {
41+
require.True(t, IsCompatibleWithUC(compute.ClusterDetails{
42+
SparkVersion: "14.x-snapshot-cpu-ml-scala2.12",
43+
DataSecurityMode: compute.DataSecurityModeUserIsolation,
44+
}, "14.0"))
45+
}
46+
47+
// TestFirstCompatibleCluster verifies that when exactly one cluster passes
// the Databricks Connect filter, AskForCluster returns its ID without
// prompting: "first shared" is rejected (runtime 12.2 < required 13.1),
// while "second personal" qualifies and is assigned to the current user.
func TestFirstCompatibleCluster(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
					{
						ClusterId:        "bcd-id",
						ClusterName:      "second personal",
						DataSecurityMode: compute.DataSecurityModeSingleUser,
						SparkVersion:     "14.5.x-whatever",
						State:            compute.StateRunning,
						SingleUserName:   "serge",
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))

	ctx := context.Background()
	// Wire cmdio into the context so the spinner has somewhere to write.
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	clusterID, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.NoError(t, err)
	require.Equal(t, "bcd-id", clusterID)
}
101+
102+
// TestNoCompatibleClusters verifies that AskForCluster returns
// ErrNoCompatibleClusters when every listed cluster is rejected by the
// filter: the only cluster runs 12.2, below the required 13.1.
func TestNoCompatibleClusters(t *testing.T) {
	cfg, server := qa.HTTPFixtures{
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/list?can_use_client=NOTEBOOKS",
			Response: compute.ListClustersResponse{
				Clusters: []compute.ClusterDetails{
					{
						ClusterId:        "abc-id",
						ClusterName:      "first shared",
						DataSecurityMode: compute.DataSecurityModeUserIsolation,
						SparkVersion:     "12.2.x-whatever",
						State:            compute.StateRunning,
					},
				},
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/preview/scim/v2/Me",
			Response: iam.User{
				UserName: "serge",
			},
		},
		{
			Method:   "GET",
			Resource: "/api/2.0/clusters/spark-versions",
			Response: compute.GetSparkVersionsResponse{
				Versions: []compute.SparkVersion{
					{
						Key:  "14.5.x-whatever",
						Name: "14.5 (Awesome)",
					},
				},
			},
		},
	}.Config(t)
	defer server.Close()
	w := databricks.Must(databricks.NewWorkspaceClient((*databricks.Config)(cfg)))

	ctx := context.Background()
	// Wire cmdio into the context so the spinner has somewhere to write.
	ctx = cmdio.InContext(ctx, cmdio.NewIO(flags.OutputText, &bytes.Buffer{}, &bytes.Buffer{}, &bytes.Buffer{}, "..."))
	_, err := AskForCluster(ctx, w, WithDatabricksConnect("13.1"))
	require.Equal(t, ErrNoCompatibleClusters, err)
}

0 commit comments

Comments
 (0)