Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion libs/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"time"

"github.com/databricks/cli/libs/auth/cache"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/databricks-sdk-go/config"
"github.com/databricks/databricks-sdk-go/retries"
"github.com/pkg/browser"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -95,6 +97,16 @@ func (a *PersistentAuth) Load(ctx context.Context) (*oauth2.Token, error) {
return refreshed, nil
}

func (a *PersistentAuth) profileName() string {
// TODO: get profile name from interactive input
if a.AccountID != "" {
return fmt.Sprintf("ACCOUNT-%s", a.AccountID)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why all caps?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

account profiles should get special treatment. i'll add ACCOUNT profile to denote the default account in the subsequent PRs

}
host := strings.TrimPrefix(a.Host, "https://")
split := strings.Split(host, ".")
return split[0]
}

func (a *PersistentAuth) Challenge(ctx context.Context) error {
err := a.init(ctx)
if err != nil {
Expand All @@ -120,7 +132,12 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error {
if err != nil {
return fmt.Errorf("store: %w", err)
}
return nil
return databrickscfg.SaveToProfile(ctx, &config.Config{
Host: a.Host,
AccountID: a.AccountID,
AuthType: "databricks-cli",
Profile: a.profileName(),
})
}

func (a *PersistentAuth) init(ctx context.Context) error {
Expand Down
88 changes: 54 additions & 34 deletions libs/databrickscfg/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package databrickscfg

import (
"context"
"errors"
"fmt"
"os"
"strings"
Expand All @@ -13,6 +14,43 @@ import (

var ResolveProfileFromHost = profileFromHostLoader{}

var errNoMatchingProfiles = errors.New("no matching config profiles found")

type errMultipleProfiles []string

func (e errMultipleProfiles) Error() string {
return fmt.Sprintf("multiple profiles matched: %s", strings.Join(e, ", "))
}

func findMatchingProfile(configFile *config.File, matcher func(*ini.Section) bool) (*ini.Section, error) {
// Look for sections in the configuration file that match the configured host.
var matching []*ini.Section
for _, section := range configFile.Sections() {
if !matcher(section) {
continue
}
matching = append(matching, section)
}

// If there are no matching sections, we don't do anything.
if len(matching) == 0 {
return nil, errNoMatchingProfiles
}

// If there are multiple matching sections, let the user know it is impossible
// to unambiguously select a profile to use.
if len(matching) > 1 {
var names errMultipleProfiles
for _, section := range matching {
names = append(names, section.Name())
}

return nil, names
}

return matching[0], nil
}
Comment thread
nfx marked this conversation as resolved.

type profileFromHostLoader struct{}

func (l profileFromHostLoader) Name() string {
Expand All @@ -27,63 +65,45 @@ func (l profileFromHostLoader) Configure(cfg *config.Config) error {
return nil
}

ctx := context.Background()
configFile, err := config.LoadFile(cfg.ConfigFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("cannot parse config file: %w", err)
}

// Normalized version of the configured host.
host := normalizeHost(cfg.Host)

// Look for sections in the configuration file that match the configured host.
var matching []*ini.Section
for _, section := range configFile.Sections() {
key, err := section.GetKey("host")
match, err := findMatchingProfile(configFile, func(s *ini.Section) bool {
key, err := s.GetKey("host")
if err != nil {
log.Tracef(context.Background(), "section %s: %s", section.Name(), err)
continue
log.Tracef(ctx, "section %s: %s", s.Name(), err)
return false
}

// Ignore this section if the normalized host doesn't match.
if normalizeHost(key.Value()) != host {
continue
}

matching = append(matching, section)
}

// If there are no matching sections, we don't do anything.
if len(matching) == 0 {
// Check if this section matches the normalized host
return normalizeHost(key.Value()) == host
})
if err == errNoMatchingProfiles {
return nil
}

// If there are multiple matching sections, let the user know it is impossible
// to unambiguously select a profile to use.
if len(matching) > 1 {
var names []string
for _, section := range matching {
names = append(names, section.Name())
}

if err, ok := err.(errMultipleProfiles); ok {
return fmt.Errorf(
"multiple profiles for host %s (%s): please set DATABRICKS_CONFIG_PROFILE to specify one",
host,
strings.Join(names, ", "),
)
"%s: %w: please set DATABRICKS_CONFIG_PROFILE to specify one",
host, err)
}
if err != nil {
return err
}

match := matching[0]
log.Debugf(context.Background(), "Loading profile %s because of host match", match.Name())
log.Debugf(ctx, "Loading profile %s because of host match", match.Name())
err = config.ConfigAttributes.ResolveFromStringMap(cfg, match.KeysHash())
if err != nil {
return fmt.Errorf("%s %s profile: %w", configFile.Path(), match.Name(), err)
}

return nil

}

func (l profileFromHostLoader) isAnyAuthConfigured(cfg *config.Config) bool {
Expand Down
2 changes: 1 addition & 1 deletion libs/databrickscfg/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,5 @@ func TestLoaderErrorsOnMultipleMatches(t *testing.T) {

err := cfg.EnsureResolved()
assert.Error(t, err)
assert.ErrorContains(t, err, "multiple profiles for host https://foo (foo1, foo2): ")
assert.ErrorContains(t, err, "https://foo: multiple profiles matched: foo1, foo2")
}
122 changes: 122 additions & 0 deletions libs/databrickscfg/ops.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package databrickscfg

import (
"context"
"fmt"
"os"
"strings"

"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go/config"
"gopkg.in/ini.v1"
)

const fileMode = 0o600

func loadOrCreateConfigFile(filename string) (*config.File, error) {
if filename == "" {
filename = "~/.databrickscfg"
}
// Expand ~ to home directory, as we need a deterministic name for os.OpenFile
// to work in the cases when ~/.databrickscfg does not exist yet
if strings.HasPrefix(filename, "~") {
homedir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("cannot find homedir: %w", err)
}
filename = fmt.Sprintf("%s%s", homedir, filename[1:])
}
configFile, err := config.LoadFile(filename)
if err != nil && os.IsNotExist(err) {
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode)
if err != nil {
return nil, fmt.Errorf("create %s: %w", filename, err)
}
defer file.Close()
configFile, err = config.LoadFile(filename)
if err != nil {
return nil, fmt.Errorf("load created %s: %w", filename, err)
}
} else if err != nil {
return nil, fmt.Errorf("parse %s: %w", filename, err)
}
return configFile, nil
}

func matchOrCreateSection(ctx context.Context, configFile *config.File, cfg *config.Config) (*ini.Section, error) {
section, err := findMatchingProfile(configFile, func(s *ini.Section) bool {
if cfg.Profile == s.Name() {
return true
}
raw := s.KeysHash()
if cfg.AccountID != "" {
// here we rely on map zerovals for matching with accounts:
// if profile has no account id, the raw["account_id"] will be empty
return cfg.AccountID == raw["account_id"]
}
if cfg.Host == "" {
return false
}
host, ok := raw["host"]
if !ok {
log.Tracef(ctx, "section %s: no host", s.Name())
return false
}
Comment thread
nfx marked this conversation as resolved.
// Check if this section matches the normalized host
return normalizeHost(host) == normalizeHost(cfg.Host)
})
if err == errNoMatchingProfiles {
section, err = configFile.NewSection(cfg.Profile)
if err != nil {
return nil, fmt.Errorf("cannot create new profile: %w", err)
}
} else if err != nil {
return nil, err
}
return section, nil
}

func SaveToProfile(ctx context.Context, cfg *config.Config) error {
configFile, err := loadOrCreateConfigFile(cfg.ConfigFile)
if err != nil {
return err
}

section, err := matchOrCreateSection(ctx, configFile, cfg)
if err != nil {
return err
}

// zeroval profile name before adding it to a section
cfg.Profile = ""
cfg.ConfigFile = ""

// clear old keys in case we're overriding the section
for _, oldKey := range section.KeyStrings() {
section.DeleteKey(oldKey)
}

for _, attr := range config.ConfigAttributes {
if attr.IsZero(cfg) {
continue
}
key := section.Key(attr.Name)
key.SetValue(attr.GetString(cfg))
}

orig, backupErr := os.ReadFile(configFile.Path())
Comment thread
nfx marked this conversation as resolved.
if len(orig) > 0 && backupErr == nil {
log.Infof(ctx, "Backing up in %s.bak", configFile.Path())
err = os.WriteFile(configFile.Path()+".bak", orig, fileMode)
if err != nil {
return fmt.Errorf("backup: %w", err)
}
log.Infof(ctx, "Overwriting %s", configFile.Path())
} else if backupErr != nil {
log.Warnf(ctx, "Failed to backup %s: %v. Proceeding to save",
configFile.Path(), backupErr)
} else {
log.Infof(ctx, "Saving %s", configFile.Path())
}
return configFile.SaveTo(configFile.Path())
}
Loading