diff --git a/internal/filer_test.go b/internal/filer_test.go index 0dda1d1bfe2..69598cb866b 100644 --- a/internal/filer_test.go +++ b/internal/filer_test.go @@ -15,7 +15,6 @@ import ( "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/service/files" - "github.com/databricks/databricks-sdk-go/service/workspace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -159,35 +158,17 @@ func runFilerReadDirTest(t *testing.T, ctx context.Context, f filer.Filer) { assert.Len(t, entries, 1) assert.Equal(t, "c", entries[0].Name()) assert.True(t, entries[0].IsDir()) -} -func temporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string { - ctx := context.Background() - me, err := w.CurrentUser.Me(ctx) - require.NoError(t, err) + // Expect an error trying to call ReadDir on a file + _, err = f.ReadDir(ctx, "/hello.txt") + assert.ErrorIs(t, err, filer.ErrNotADirectory) - path := fmt.Sprintf("/Users/%s/%s", me.UserName, RandomName("integration-test-filer-wsfs-")) - - // Ensure directory exists, but doesn't exist YET! - // Otherwise we could inadvertently remove a directory that already exists on cleanup. - t.Logf("mkdir %s", path) - err = w.Workspace.MkdirsByPath(ctx, path) + // Expect 0 entries for an empty directory + err = f.Mkdir(ctx, "empty-dir") require.NoError(t, err) - - // Remove test directory on test completion. 
- t.Cleanup(func() { - t.Logf("rm -rf %s", path) - err := w.Workspace.Delete(ctx, workspace.Delete{ - Path: path, - Recursive: true, - }) - if err == nil || apierr.IsMissing(err) { - return - } - t.Logf("unable to remove temporary workspace directory %s: %#v", path, err) - }) - - return path + entries, err = f.ReadDir(ctx, "empty-dir") + assert.NoError(t, err) + assert.Len(t, entries, 0) } func setupWorkspaceFilesTest(t *testing.T) (context.Context, filer.Filer) { diff --git a/internal/helpers.go b/internal/helpers.go index b51d005b27e..63a76c0bbe4 100644 --- a/internal/helpers.go +++ b/internal/helpers.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "fmt" + "io" "math/rand" "os" "path/filepath" @@ -14,6 +15,11 @@ import ( "github.com/databricks/cli/cmd/root" _ "github.com/databricks/cli/cmd/version" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" _ "github.com/databricks/cli/cmd/workspace" @@ -176,3 +182,52 @@ func writeFile(t *testing.T, name string, body string) string { f.Close() return f.Name() } + +func temporaryWorkspaceDir(t *testing.T, w *databricks.WorkspaceClient) string { + ctx := context.Background() + me, err := w.CurrentUser.Me(ctx) + require.NoError(t, err) + + path := fmt.Sprintf("/Users/%s/%s", me.UserName, RandomName("integration-test-wsfs-")) + + // Ensure directory exists, but doesn't exist YET! + // Otherwise we could inadvertently remove a directory that already exists on cleanup. + t.Logf("mkdir %s", path) + err = w.Workspace.MkdirsByPath(ctx, path) + require.NoError(t, err) + + // Remove test directory on test completion. 
+ t.Cleanup(func() { + t.Logf("rm -rf %s", path) + err := w.Workspace.Delete(ctx, workspace.Delete{ + Path: path, + Recursive: true, + }) + if err == nil || apierr.IsMissing(err) { + return + } + t.Logf("unable to remove temporary workspace directory %s: %#v", path, err) + }) + + return path +} + +func assertFileContains(t *testing.T, ctx context.Context, f filer.Filer, name, contents string) { + r, err := f.Read(ctx, name) + require.NoError(t, err) + + var b bytes.Buffer + _, err = io.Copy(&b, r) + require.NoError(t, err) + + assert.Contains(t, b.String(), contents) +} + +func assertNotebookExists(t *testing.T, ctx context.Context, w *databricks.WorkspaceClient, path string) { + info, err := w.Workspace.ListAll(ctx, workspace.ListWorkspaceRequest{ + Path: path, + }) + require.NoError(t, err) + assert.Len(t, info, 1) + assert.Equal(t, info[0].ObjectType, workspace.ObjectTypeNotebook) +} diff --git a/internal/repofiles_test.go b/internal/repofiles_test.go new file mode 100644 index 00000000000..6ffbd30e83b --- /dev/null +++ b/internal/repofiles_test.go @@ -0,0 +1,251 @@ +package internal + +import ( + "context" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/sync/repofiles" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type repofilesTestHelper struct { + w *databricks.WorkspaceClient + f filer.Filer + ctx context.Context + t *testing.T + + localRoot string + remoteRoot string +} + +func setupRepofilesTestHelper(t *testing.T, ctx context.Context) *repofilesTestHelper { + // t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV")) + + w, err := databricks.NewWorkspaceClient() + require.NoError(t, err) + + // initialize client + wsfsTmpDir := temporaryWorkspaceDir(t, w) + localTmpDir := t.TempDir() + + require.NoError(t, err) + f, err := 
filer.NewWorkspaceFilesClient(w, wsfsTmpDir) + require.NoError(t, err) + + return &repofilesTestHelper{ + w: w, + f: f, + ctx: ctx, + t: t, + + localRoot: localTmpDir, + remoteRoot: wsfsTmpDir, + } +} + +func (h *repofilesTestHelper) createLocalFile(name string, content string) { + absPath := filepath.Join(h.localRoot, name) + err := os.MkdirAll(filepath.Dir(absPath), os.ModePerm) + require.NoError(h.t, err) + err = os.WriteFile(absPath, []byte(content), os.ModePerm) + require.NoError(h.t, err) +} + +func (h *repofilesTestHelper) createRemoteFile(name string, content string) { + h.f.Write(h.ctx, name, strings.NewReader(content), filer.CreateParentDirectories) +} + +func (h *repofilesTestHelper) createRemoteDirectory(name string) { + h.f.Mkdir(h.ctx, name) +} + +func (h *repofilesTestHelper) assertRemoteFileContent(name string, content string) { + assertFileContains(h.t, h.ctx, h.f, name, content) +} + +func (h *repofilesTestHelper) assertRemoteFileType(name string, fileType workspace.ObjectType) { + info, err := h.f.Stat(h.ctx, name) + require.NoError(h.t, err) + + objectInfo := info.Sys().(workspace.ObjectInfo) + assert.Equal(h.t, fileType, objectInfo.ObjectType) +} + +func TestRepoFilesPutFile(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: true, + }) + require.NoError(t, err) + + // create local file + helper.createLocalFile("foo.txt", "hello, world") + err = r.PutFile(ctx, "foo.txt") + require.NoError(t, err) + + // Expect PUT to succeed + helper.assertRemoteFileContent("foo.txt", "hello, world") +} + +func TestRepoFilesPutFileOverwritesNotebook(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: true, + }) + require.NoError(t, 
err) + + // Create notebook in workspace + helper.createRemoteFile("foo.py", "#Databricks notebook source\nprint(1)") + helper.assertRemoteFileType("foo", workspace.ObjectTypeNotebook) + + // Put file and assert file PUT succeeded + helper.createLocalFile("foo", "this file will overwrite the notebook") + err = r.PutFile(ctx, "foo") + assert.NoError(t, err) + helper.assertRemoteFileContent("foo", "this file will overwrite the notebook") + helper.assertRemoteFileType("foo", workspace.ObjectTypeFile) +} + +func TestRepoFilesPutFileOverwritesEmptyDirectoryTree(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: true, + }) + require.NoError(t, err) + + // create empty remote directory tree + helper.createRemoteDirectory("foo/a/b/c") + helper.createRemoteDirectory("foo/a/b/d/e") + helper.createRemoteDirectory("foo/f/g/i") + + // assert directory tree is created + helper.assertRemoteFileType("foo", workspace.ObjectTypeDirectory) + helper.assertRemoteFileType("foo/a/b/c", workspace.ObjectTypeDirectory) + helper.assertRemoteFileType("foo/f/g/i", workspace.ObjectTypeDirectory) + helper.assertRemoteFileType("foo/a/b/d/e", workspace.ObjectTypeDirectory) + + // Create local file and PUT it into the workspace + helper.createLocalFile("foo", "hello, world") + err = r.PutFile(ctx, "foo") + require.NoError(t, err) + helper.assertRemoteFileContent("foo", "hello, world") + helper.assertRemoteFileType("foo", workspace.ObjectTypeFile) +} + +func TestRepoFilesPutFileInDirOverwritesExistingNotebook(t *testing.T) { + // TODO: Skipping this test for now since the workspace-files import API has a + // bug and does not return the error message we need + t.SkipNow() + + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, 
&repofiles.RepoFileOptions{ + OverwriteIfExists: true, + }) + require.NoError(t, err) + + // create remote notebook + helper.createRemoteFile("foo.py", "#Databricks notebook source\nprint(1)") + helper.assertRemoteFileType("foo", workspace.ObjectTypeNotebook) + + // create local file and PUT it in the workspace + helper.createLocalFile("foo/hello.txt", "just a file") + err = r.PutFile(ctx, "foo/hello.txt") + require.NoError(t, err) + + // Assert PUT succeeded + helper.assertRemoteFileType("foo", workspace.ObjectTypeDirectory) + helper.assertRemoteFileContent("foo/bar.txt", "just a file") +} + +func TestRepoFilesPutFileWithoutOverwrite(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: false, + }) + require.NoError(t, err) + + // create local file + helper.createLocalFile("foo.txt", "hello, world") + err = r.PutFile(ctx, "foo.txt") + require.NoError(t, err) + + // Expect PUT to succeed + helper.assertRemoteFileContent("foo.txt", "hello, world") +} + +func TestRepoFilesPutFileWithoutOverwriteFails(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: false, + }) + require.NoError(t, err) + + // create remote file + helper.createRemoteFile("foo.txt", "this file already exists in the workspace") + + // create local file + helper.createLocalFile("foo.txt", "this file will attempt to overwrite the workspace file and fail") + + // assert overwrite fails + err = r.PutFile(ctx, "foo.txt") + assert.ErrorIs(t, err, fs.ErrExist) +} + +func TestRepoFilesPutFileWithoutOverwriteFailsIfDirectoryExists(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, 
helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: false, + }) + require.NoError(t, err) + + helper.createRemoteDirectory("foo") + + // create local file + helper.createLocalFile("foo", "hello, world") + err = r.PutFile(ctx, "foo") + + // Assert PUT failed because file already exists + assert.ErrorIs(t, err, fs.ErrExist) +} + +func TestRepoFilesPutFileWithoutOverwriteFailsIfNotebookExists(t *testing.T) { + ctx := context.Background() + helper := setupRepofilesTestHelper(t, ctx) + + r, err := repofiles.Create(helper.remoteRoot, helper.localRoot, helper.w, &repofiles.RepoFileOptions{ + OverwriteIfExists: false, + }) + require.NoError(t, err) + + // create remote notebook + helper.createRemoteFile("foo.py", "#Databricks notebook source\nprint(1)") + + // create local file + helper.createLocalFile("foo", "hello, world") + err = r.PutFile(ctx, "foo") + + // Assert PUT failed because file already exists + assert.ErrorIs(t, err, fs.ErrExist) +} diff --git a/libs/filer/dbfs_client.go b/libs/filer/dbfs_client.go index c86a80b1e15..d4397686994 100644 --- a/libs/filer/dbfs_client.go +++ b/libs/filer/dbfs_client.go @@ -59,7 +59,7 @@ func (info dbfsFileInfo) IsDir() bool { } func (info dbfsFileInfo) Sys() any { - return nil + return info.fi } // DbfsClient implements the [Filer] interface for the DBFS backend. 
@@ -222,6 +222,10 @@ func (w *DbfsClient) ReadDir(ctx context.Context, name string) ([]fs.DirEntry, e return nil, err } + if len(res.Files) == 1 && res.Files[0].Path == absPath { + return nil, NotADirectory{absPath} + } + info := make([]fs.DirEntry, len(res.Files)) for i, v := range res.Files { info[i] = dbfsDirEntry{dbfsFileInfo: dbfsFileInfo{fi: v}} diff --git a/libs/filer/filer.go b/libs/filer/filer.go index ff01ea79816..1525aba3a0e 100644 --- a/libs/filer/filer.go +++ b/libs/filer/filer.go @@ -2,6 +2,7 @@ package filer import ( "context" + "errors" "fmt" "io" "io/fs" @@ -50,6 +51,20 @@ func (err NoSuchDirectoryError) Is(other error) bool { return other == fs.ErrNotExist } +var ErrNotADirectory = errors.New("not a directory") + +type NotADirectory struct { + path string +} + +func (err NotADirectory) Error() string { + return fmt.Sprintf("%s is not a directory", err.path) +} + +func (err NotADirectory) Is(other error) bool { + return other == ErrNotADirectory +} + // Filer is used to access files in a workspace. // It has implementations for accessing files in WSFS and in DBFS. type Filer interface { diff --git a/libs/filer/workspace_files_client.go b/libs/filer/workspace_files_client.go index 967f9a1de5e..50ccfed7dfa 100644 --- a/libs/filer/workspace_files_client.go +++ b/libs/filer/workspace_files_client.go @@ -65,7 +65,7 @@ func (info wsfsFileInfo) IsDir() bool { } func (info wsfsFileInfo) Sys() any { - return nil + return info.oi } // WorkspaceFilesClient implements the files-in-workspace API. @@ -222,6 +222,11 @@ func (w *WorkspaceFilesClient) ReadDir(ctx context.Context, name string) ([]fs.D objects, err := w.workspaceClient.Workspace.ListAll(ctx, workspace.ListWorkspaceRequest{ Path: absPath, }) + + if len(objects) == 1 && objects[0].Path == absPath { + return nil, NotADirectory{absPath} + } + if err != nil { // If we got an API error we deal with it below. 
var aerr *apierr.APIError diff --git a/libs/sync/repofiles/repofiles.go b/libs/sync/repofiles/repofiles.go index 8fcabc113ec..f856793af6a 100644 --- a/libs/sync/repofiles/repofiles.go +++ b/libs/sync/repofiles/repofiles.go @@ -1,36 +1,52 @@ package repofiles import ( + "bytes" "context" "errors" "fmt" - "net/http" - "net/url" "os" "path" "path/filepath" "strings" + "github.com/databricks/cli/libs/filer" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/apierr" - "github.com/databricks/databricks-sdk-go/client" "github.com/databricks/databricks-sdk-go/service/workspace" ) +type RepoFileOptions struct { + OverwriteIfExists bool +} + // RepoFiles wraps reading and writing into a remote repo with safeguards to prevent // accidental deletion of repos and more robust methods to overwrite workspace files type RepoFiles struct { + *RepoFileOptions + repoRoot string localRoot string workspaceClient *databricks.WorkspaceClient + f filer.Filer } -func Create(repoRoot, localRoot string, workspaceClient *databricks.WorkspaceClient) *RepoFiles { +func Create(repoRoot, localRoot string, w *databricks.WorkspaceClient, opts *RepoFileOptions) (*RepoFiles, error) { + // override default timeout to support uploading larger files + w.Config.HTTPTimeoutSeconds = 600 + + // create filer to interact with WSFS + f, err := filer.NewWorkspaceFilesClient(w, repoRoot) + if err != nil { + return nil, err + } return &RepoFiles{ repoRoot: repoRoot, localRoot: localRoot, - workspaceClient: workspaceClient, - } + workspaceClient: w, + RepoFileOptions: opts, + f: f, + }, nil } func (r *RepoFiles) remotePath(relativePath string) (string, error) { @@ -52,36 +68,25 @@ func (r *RepoFiles) readLocal(relativePath string) ([]byte, error) { } func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, content []byte) error { - apiClientConfig := r.workspaceClient.Config - apiClientConfig.HTTPTimeoutSeconds = 600 - apiClient, err := client.New(apiClientConfig) - 
if err != nil { - return err + if !r.OverwriteIfExists { + return r.f.Write(ctx, relativePath, bytes.NewReader(content), filer.CreateParentDirectories) } - remotePath, err := r.remotePath(relativePath) - if err != nil { - return err - } - escapedPath := url.PathEscape(strings.TrimLeft(remotePath, "/")) - apiPath := fmt.Sprintf("/api/2.0/workspace-files/import-file/%s?overwrite=true", escapedPath) - - err = apiClient.Do(ctx, http.MethodPost, apiPath, content, nil) - - // Handling some edge cases when an upload might fail - // - // We cannot do more precise error scoping here because the API does not - // provide descriptive errors yet - // - // TODO: narrow down the error condition scope of this "if" block to only - // trigger for the specific edge cases instead of all errors once the API - // implements them + + err := r.f.Write(ctx, relativePath, bytes.NewReader(content), filer.CreateParentDirectories, filer.OverwriteIfExists) + + // TODO(pietern): Use the new FS interface to avoid needing to make a recursive + // delete call here. This call is dangerous if err != nil { // Delete any artifact files incase non overwriteable by the current file // type and thus are failing the PUT request. // files, folders and notebooks might not have been cleaned up and they // can't overwrite each other. 
If a folder `foo` exists, then attempts to // PUT a file `foo` will fail - err := r.workspaceClient.Workspace.Delete(ctx, + remotePath, err := r.remotePath(relativePath) + if err != nil { + return err + } + err = r.workspaceClient.Workspace.Delete(ctx, workspace.Delete{ Path: remotePath, Recursive: true, @@ -96,33 +101,15 @@ func (r *RepoFiles) writeRemote(ctx context.Context, relativePath string, conten return err } - // Mkdir parent dirs incase they are what's causing the PUT request to - // fail - err = r.workspaceClient.Workspace.MkdirsByPath(ctx, path.Dir(remotePath)) - if err != nil { - return fmt.Errorf("could not mkdir to put file: %s", err) - } - - // Attempt to upload file again after cleanup/setup - err = apiClient.Do(ctx, http.MethodPost, apiPath, content, nil) - if err != nil { - return err - } + // Attempt to write the file again, this time without the CreateParentDirectories and + // OverwriteIfExists flags + return r.f.Write(ctx, relativePath, bytes.NewReader(content)) } return nil } func (r *RepoFiles) deleteRemote(ctx context.Context, relativePath string) error { - remotePath, err := r.remotePath(relativePath) - if err != nil { - return err - } - return r.workspaceClient.Workspace.Delete(ctx, - workspace.Delete{ - Path: remotePath, - Recursive: false, - }, - ) + return r.f.Delete(ctx, relativePath) } // The API calls for a python script foo.py would be @@ -154,6 +141,3 @@ func (r *RepoFiles) DeleteFile(ctx context.Context, relativePath string) error { } return nil } - -// TODO: write integration tests for all non happy path cases that rely on -// specific behaviour of the workspace apis diff --git a/libs/sync/repofiles/repofiles_test.go b/libs/sync/repofiles/repofiles_test.go index 2a881d90d06..dc9abbcddf4 100644 --- a/libs/sync/repofiles/repofiles_test.go +++ b/libs/sync/repofiles/repofiles_test.go @@ -6,11 +6,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRepoFilesRemotePath(t 
*testing.T) { repoRoot := "/Repos/doraemon/bar" - repoFiles := Create(repoRoot, "/doraemon/foo/bar", nil) + repoFiles, err := Create(repoRoot, "/doraemon/foo/bar", nil, nil) + require.NoError(t, err) remotePath, err := repoFiles.remotePath("a/b/c") assert.NoError(t, err) @@ -81,7 +83,8 @@ func TestRepoReadLocal(t *testing.T) { err := os.WriteFile(helloPath, []byte("my name is doraemon :P"), os.ModePerm) assert.NoError(t, err) - repoFiles := Create("/Repos/doraemon/bar", tempDir, nil) + repoFiles, err := Create("/Repos/doraemon/bar", tempDir, nil, nil) + require.NoError(t, err) bytes, err := repoFiles.readLocal("./a/../hello.txt") assert.NoError(t, err) assert.Equal(t, "my name is doraemon :P", string(bytes)) diff --git a/libs/sync/sync.go b/libs/sync/sync.go index 54d0624e77c..65bad57f08c 100644 --- a/libs/sync/sync.go +++ b/libs/sync/sync.go @@ -77,7 +77,10 @@ func New(ctx context.Context, opts SyncOptions) (*Sync, error) { } } - repoFiles := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient) + repoFiles, err := repofiles.Create(opts.RemotePath, opts.LocalPath, opts.WorkspaceClient, &repofiles.RepoFileOptions{OverwriteIfExists: true}) + if err != nil { + return nil, err + } return &Sync{ SyncOptions: &opts,