diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index af4b5e09658..1171103cbbd 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -16,6 +16,8 @@ import ( "time" "github.com/databricks/cli/libs/auth/cache" + "github.com/databricks/cli/libs/databrickscfg" + "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/retries" "github.com/pkg/browser" "golang.org/x/oauth2" @@ -95,6 +97,16 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) { return refreshed, nil } +func (a *PersistentAuth) profileName() string { + // TODO: get profile name from interactive input + if a.AccountID != "" { + return fmt.Sprintf("ACCOUNT-%s", a.AccountID) + } + host := strings.TrimPrefix(a.Host, "https://") + split := strings.Split(host, ".") + return split[0] +} + func (a *PersistentAuth) Challenge(ctx context.Context) error { err := a.init(ctx) if err != nil { @@ -120,7 +132,12 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { if err != nil { return fmt.Errorf("store: %w", err) } - return nil + return databrickscfg.SaveToProfile(ctx, &config.Config{ + Host: a.Host, + AccountID: a.AccountID, + AuthType: "databricks-cli", + Profile: a.profileName(), + }) } func (a *PersistentAuth) init(ctx context.Context) error { diff --git a/libs/databrickscfg/loader.go b/libs/databrickscfg/loader.go index 087703ff847..8179703a318 100644 --- a/libs/databrickscfg/loader.go +++ b/libs/databrickscfg/loader.go @@ -2,6 +2,7 @@ package databrickscfg import ( "context" + "errors" "fmt" "os" "strings" @@ -13,6 +14,43 @@ import ( var ResolveProfileFromHost = profileFromHostLoader{} +var errNoMatchingProfiles = errors.New("no matching config profiles found") + +type errMultipleProfiles []string + +func (e errMultipleProfiles) Error() string { + return fmt.Sprintf("multiple profiles matched: %s", strings.Join(e, ", ")) +} + +func findMatchingProfile(configFile *config.File, matcher func(*ini.Section) bool) (*ini.Section, error) { + // Look for sections in the configuration file that match the configured host. + var matching []*ini.Section + for _, section := range configFile.Sections() { + if !matcher(section) { + continue + } + matching = append(matching, section) + } + + // If there are no matching sections, we don't do anything. + if len(matching) == 0 { + return nil, errNoMatchingProfiles + } + + // If there are multiple matching sections, let the user know it is impossible + // to unambiguously select a profile to use. + if len(matching) > 1 { + var names errMultipleProfiles + for _, section := range matching { + names = append(names, section.Name()) + } + + return nil, names + } + + return matching[0], nil +} + type profileFromHostLoader struct{} func (l profileFromHostLoader) Name() string { @@ -27,6 +65,7 @@ func (l profileFromHostLoader) Configure(cfg *config.Config) error { return nil } + ctx := context.Background() configFile, err := config.LoadFile(cfg.ConfigFile) if err != nil { if os.IsNotExist(err) { @@ -34,56 +73,37 @@ func (l profileFromHostLoader) Configure(cfg *config.Config) error { } return fmt.Errorf("cannot parse config file: %w", err) } - // Normalized version of the configured host. host := normalizeHost(cfg.Host) - - // Look for sections in the configuration file that match the configured host. - var matching []*ini.Section - for _, section := range configFile.Sections() { - key, err := section.GetKey("host") + match, err := findMatchingProfile(configFile, func(s *ini.Section) bool { + key, err := s.GetKey("host") if err != nil { - log.Tracef(context.Background(), "section %s: %s", section.Name(), err) - continue + log.Tracef(ctx, "section %s: %s", s.Name(), err) + return false } - // Ignore this section if the normalized host doesn't match. - if normalizeHost(key.Value()) != host { - continue - } - - matching = append(matching, section) - } - - // If there are no matching sections, we don't do anything. - if len(matching) == 0 { + // Check if this section matches the normalized host + return normalizeHost(key.Value()) == host + }) + if err == errNoMatchingProfiles { return nil } - - // If there are multiple matching sections, let the user know it is impossible - // to unambiguously select a profile to use. - if len(matching) > 1 { - var names []string - for _, section := range matching { - names = append(names, section.Name()) - } - + if err, ok := err.(errMultipleProfiles); ok { return fmt.Errorf( - "multiple profiles for host %s (%s): please set DATABRICKS_CONFIG_PROFILE to specify one", - host, - strings.Join(names, ", "), - ) + "%s: %w: please set DATABRICKS_CONFIG_PROFILE to specify one", + host, err) + } + if err != nil { + return err } - match := matching[0] - log.Debugf(context.Background(), "Loading profile %s because of host match", match.Name()) + log.Debugf(ctx, "Loading profile %s because of host match", match.Name()) err = config.ConfigAttributes.ResolveFromStringMap(cfg, match.KeysHash()) if err != nil { return fmt.Errorf("%s %s profile: %w", configFile.Path(), match.Name(), err) } return nil - } func (l profileFromHostLoader) isAnyAuthConfigured(cfg *config.Config) bool { diff --git a/libs/databrickscfg/loader_test.go b/libs/databrickscfg/loader_test.go index 59610858c9c..5fa7f7dd2e0 100644 --- a/libs/databrickscfg/loader_test.go +++ b/libs/databrickscfg/loader_test.go @@ -126,5 +126,5 @@ func TestLoaderErrorsOnMultipleMatches(t *testing.T) { err := cfg.EnsureResolved() assert.Error(t, err) - assert.ErrorContains(t, err, "multiple profiles for host https://foo (foo1, foo2): ") + assert.ErrorContains(t, err, "https://foo: multiple profiles matched: foo1, foo2") } diff --git a/libs/databrickscfg/ops.go b/libs/databrickscfg/ops.go new file mode 100644 index 00000000000..8b4e1a5f30a --- /dev/null +++ b/libs/databrickscfg/ops.go @@ -0,0 +1,122 @@ +package databrickscfg + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/config" + "gopkg.in/ini.v1" +) + +const fileMode = 0o600 + +func loadOrCreateConfigFile(filename string) (*config.File, error) { + if filename == "" { + filename = "~/.databrickscfg" + } + // Expand ~ to home directory, as we need a deterministic name for os.OpenFile + // to work in the cases when ~/.databrickscfg does not exist yet + if strings.HasPrefix(filename, "~") { + homedir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("cannot find homedir: %w", err) + } + filename = fmt.Sprintf("%s%s", homedir, filename[1:]) + } + configFile, err := config.LoadFile(filename) + if err != nil && os.IsNotExist(err) { + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode) + if err != nil { + return nil, fmt.Errorf("create %s: %w", filename, err) + } + defer file.Close() + configFile, err = config.LoadFile(filename) + if err != nil { + return nil, fmt.Errorf("load created %s: %w", filename, err) + } + } else if err != nil { + return nil, fmt.Errorf("parse %s: %w", filename, err) + } + return configFile, nil +} + +func matchOrCreateSection(ctx context.Context, configFile *config.File, cfg *config.Config) (*ini.Section, error) { + section, err := findMatchingProfile(configFile, func(s *ini.Section) bool { + if cfg.Profile == s.Name() { + return true + } + raw := s.KeysHash() + if cfg.AccountID != "" { + // here we rely on map zerovals for matching with accounts: + // if profile has no account id, the raw["account_id"] will be empty + return cfg.AccountID == raw["account_id"] + } + if cfg.Host == "" { + return false + } + host, ok := raw["host"] + if !ok { + log.Tracef(ctx, "section %s: no host", s.Name()) + return false + } + // Check if this section matches the normalized host + return normalizeHost(host) == normalizeHost(cfg.Host) + }) + if err == errNoMatchingProfiles { + section, err = configFile.NewSection(cfg.Profile) + if err != nil { + return nil, fmt.Errorf("cannot create new profile: %w", err) + } + } else if err != nil { + return nil, err + } + return section, nil +} + +func SaveToProfile(ctx context.Context, cfg *config.Config) error { + configFile, err := loadOrCreateConfigFile(cfg.ConfigFile) + if err != nil { + return err + } + + section, err := matchOrCreateSection(ctx, configFile, cfg) + if err != nil { + return err + } + + // zeroval profile name before adding it to a section + cfg.Profile = "" + cfg.ConfigFile = "" + + // clear old keys in case we're overriding the section + for _, oldKey := range section.KeyStrings() { + section.DeleteKey(oldKey) + } + + for _, attr := range config.ConfigAttributes { + if attr.IsZero(cfg) { + continue + } + key := section.Key(attr.Name) + key.SetValue(attr.GetString(cfg)) + } + + orig, backupErr := os.ReadFile(configFile.Path()) + if len(orig) > 0 && backupErr == nil { + log.Infof(ctx, "Backing up in %s.bak", configFile.Path()) + err = os.WriteFile(configFile.Path()+".bak", orig, fileMode) + if err != nil { + return fmt.Errorf("backup: %w", err) + } + log.Infof(ctx, "Overwriting %s", configFile.Path()) + } else if backupErr != nil { + log.Warnf(ctx, "Failed to backup %s: %v. Proceeding to save", + configFile.Path(), backupErr) + } else { + log.Infof(ctx, "Saving %s", configFile.Path()) + } + return configFile.SaveTo(configFile.Path()) +} diff --git a/libs/databrickscfg/ops_test.go b/libs/databrickscfg/ops_test.go new file mode 100644 index 00000000000..64b4fbadfda --- /dev/null +++ b/libs/databrickscfg/ops_test.go @@ -0,0 +1,192 @@ +package databrickscfg + +import ( + "context" + "path/filepath" + "testing" + + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" +) + +func TestLoadOrCreate(t *testing.T) { + dir := t.TempDir() + + path := filepath.Join(dir, "databrickscfg") + file, err := loadOrCreateConfigFile(path) + assert.NoError(t, err) + assert.NotNil(t, file) + assert.FileExists(t, path) +} + +func TestLoadOrCreate_NotAllowed(t *testing.T) { + path := "/dev/databrickscfg" + file, err := loadOrCreateConfigFile(path) + assert.Error(t, err) + assert.Nil(t, file) + assert.NoFileExists(t, path) +} + +func TestLoadOrCreate_Bad(t *testing.T) { + path := "testdata/badcfg" + file, err := loadOrCreateConfigFile(path) + assert.Error(t, err) + assert.Nil(t, file) +} + +func TestMatchOrCreateSection_Direct(t *testing.T) { + cfg := &config.Config{ + Profile: "query", + } + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + section, err := matchOrCreateSection(ctx, file, cfg) + assert.NoError(t, err) + assert.NotNil(t, section) + assert.Equal(t, "query", section.Name()) +} + +func TestMatchOrCreateSection_AccountID(t *testing.T) { + cfg := &config.Config{ + AccountID: "abc", + } + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + section, err := matchOrCreateSection(ctx, file, cfg) + assert.NoError(t, err) + assert.NotNil(t, section) + assert.Equal(t, "acc", section.Name()) +} + +func TestMatchOrCreateSection_NormalizeHost(t *testing.T) { + cfg := &config.Config{ + Host: "https://query/?o=abracadabra", + } + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + section, err := matchOrCreateSection(ctx, file, cfg) + assert.NoError(t, err) + assert.NotNil(t, section) + assert.Equal(t, "query", section.Name()) +} + +func TestMatchOrCreateSection_NoProfileOrHost(t *testing.T) { + cfg := &config.Config{} + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + _, err = matchOrCreateSection(ctx, file, cfg) + assert.EqualError(t, err, "cannot create new profile: empty section name") +} + +func TestMatchOrCreateSection_MultipleProfiles(t *testing.T) { + cfg := &config.Config{ + Host: "https://foo", + } + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + _, err = matchOrCreateSection(ctx, file, cfg) + assert.EqualError(t, err, "multiple profiles matched: foo1, foo2") +} + +func TestMatchOrCreateSection_NewProfile(t *testing.T) { + cfg := &config.Config{ + Host: "https://bar", + Profile: "delirium", + } + file, err := loadOrCreateConfigFile("testdata/databrickscfg") + assert.NoError(t, err) + + ctx := context.Background() + section, err := matchOrCreateSection(ctx, file, cfg) + assert.NoError(t, err) + assert.NotNil(t, section) + assert.Equal(t, "delirium", section.Name()) +} + +func TestSaveToProfile_ErrorOnLoad(t *testing.T) { + ctx := context.Background() + err := SaveToProfile(ctx, &config.Config{ + ConfigFile: "testdata/badcfg", + }) + assert.Error(t, err) +} + +func TestSaveToProfile_ErrorOnMatch(t *testing.T) { + ctx := context.Background() + err := SaveToProfile(ctx, &config.Config{ + Host: "https://foo", + }) + assert.Error(t, err) +} + +func TestSaveToProfile_NewFile(t *testing.T) { + ctx := context.Background() + path := filepath.Join(t.TempDir(), "databrickscfg") + + err := SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Profile: "abc", + Host: "https://foo", + Token: "xyz", + }) + assert.NoError(t, err) + assert.NoFileExists(t, path+".bak") +} + +func TestSaveToProfile_ClearingPreviousProfile(t *testing.T) { + ctx := context.Background() + path := filepath.Join(t.TempDir(), "databrickscfg") + + err := SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Profile: "abc", + Host: "https://foo", + Token: "xyz", + }) + assert.NoError(t, err) + + err = SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Profile: "bcd", + Host: "https://bar", + Token: "zyx", + }) + assert.NoError(t, err) + assert.FileExists(t, path+".bak") + + err = SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Host: "https://foo", + AuthType: "databricks-cli", + }) + assert.NoError(t, err) + + file, err := loadOrCreateConfigFile(path) + assert.NoError(t, err) + + assert.Len(t, file.Sections(), 3) + assert.True(t, file.HasSection("DEFAULT")) + assert.True(t, file.HasSection("bcd")) + assert.True(t, file.HasSection("bcd")) + + dlft, err := file.GetSection("DEFAULT") + assert.NoError(t, err) + assert.Len(t, dlft.KeysHash(), 0) + + abc, err := file.GetSection("abc") + assert.NoError(t, err) + raw := abc.KeysHash() + assert.Len(t, raw, 2) + assert.Equal(t, "https://foo", raw["host"]) + assert.Equal(t, "databricks-cli", raw["auth_type"]) +} diff --git a/libs/databrickscfg/testdata/databrickscfg b/libs/databrickscfg/testdata/databrickscfg index ad81933e48f..ba045c6c284 100644 --- a/libs/databrickscfg/testdata/databrickscfg +++ b/libs/databrickscfg/testdata/databrickscfg @@ -14,6 +14,10 @@ token = query host = https://foo token = foo1 +[acc] +host = https://accounts.cloud.databricks.com +account_id = abc + # Duplicate entry for https://foo [foo2] host = https://foo