Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5294,6 +5294,79 @@ async def test_delete_rag_file_flattened_error_async():
)


@pytest.mark.parametrize(
"request_type",
[
vertex_rag_data_service.BatchCreateRagDataSchemasRequest,
dict,
],
)
def test_batch_create_rag_data_schemas(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.batch_create_rag_data_schemas), "__call__"
) as call:
# Designate an appropriate return value for the call.
call.return_value = operations_pb2.Operation(name="operations/spam")
response = client.batch_create_rag_data_schemas(request)

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
request = vertex_rag_data_service.BatchCreateRagDataSchemasRequest()
assert args[0] == request

# Establish that the response is the type that we expect.
assert isinstance(response, future.Future)


@pytest.mark.parametrize(
"request_type",
[
vertex_rag_data_service.ListRagDataSchemasRequest,
dict,
],
)
def test_list_rag_data_schemas(request_type, transport: str = "grpc"):
client = VertexRagDataServiceClient(
credentials=ga_credentials.AnonymousCredentials(),
transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(
type(client.transport.list_rag_data_schemas), "__call__"
) as call:
# Designate an appropriate return value for the call.
call.return_value = vertex_rag_data_service.ListRagDataSchemasResponse(
next_page_token="next_page_token_value",
)
response = client.list_rag_data_schemas(request)

# Establish that the underlying gRPC stub method was called.
assert len(call.mock_calls) == 1
_, args, _ = call.mock_calls[0]
request = vertex_rag_data_service.ListRagDataSchemasRequest()
assert args[0] == request

# Establish that the response is the type that we expect.
assert isinstance(response, pagers.ListRagDataSchemasPager)
assert response.next_page_token == "next_page_token_value"


@pytest.mark.parametrize(
"request_type",
[
Expand Down
30 changes: 21 additions & 9 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,19 @@ def google_auth_mock():

@pytest.fixture
def authorized_session_mock():
with patch(
"google.auth.transport.requests.AuthorizedSession"
) as MockAuthorizedSession:
from google.auth.transport import requests

with mock.patch.object(requests, "AuthorizedSession") as MockAuthorizedSession:
mock_auth_session = MockAuthorizedSession(_TEST_CREDENTIALS)
yield mock_auth_session


@pytest.fixture
def rag_data_client_mock():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)

Expand All @@ -84,8 +86,10 @@ def rag_data_client_mock():

@pytest.fixture
def rag_data_client_preview_mock():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)

Expand All @@ -108,8 +112,10 @@ def rag_data_client_preview_mock():

@pytest.fixture
def rag_data_client_mock_exception():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClient)
# create_rag_corpus
Expand Down Expand Up @@ -138,8 +144,10 @@ def rag_data_client_mock_exception():

@pytest.fixture
def rag_data_client_preview_mock_exception():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_client"
_gapic_utils, "create_rag_data_service_client"
) as rag_data_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceClientPreview)
# create_rag_corpus
Expand Down Expand Up @@ -172,8 +180,10 @@ def rag_data_client_preview_mock_exception():

@pytest.fixture
def rag_data_async_client_mock_exception():
from vertexai.rag.utils import _gapic_utils

with mock.patch.object(
rag.utils._gapic_utils, "create_rag_data_service_async_client"
_gapic_utils, "create_rag_data_service_async_client"
) as rag_data_async_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClient)
# import_rag_files
Expand All @@ -184,8 +194,10 @@ def rag_data_async_client_mock_exception():

@pytest.fixture
def rag_data_async_client_preview_mock_exception():
from vertexai.preview.rag.utils import _gapic_utils

with mock.patch.object(
rag_preview.utils._gapic_utils, "create_rag_data_service_async_client"
_gapic_utils, "create_rag_data_service_async_client"
) as rag_data_async_client_mock_exception:
api_client_mock = mock.Mock(spec=VertexRagDataServiceAsyncClientPreview)
# import_rag_files
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,23 @@
ImportRagFilesRequest,
ImportRagFilesResponse,
JiraSource as GapicJiraSource,
MetadataValue as GapicMetadataValue,
RagContexts,
RagCorpus as GapicRagCorpus,
RagDataSchema as GapicRagDataSchema,
RagEngineConfig as GapicRagEngineConfig,
RagFileChunkingConfig,
RagFileParsingConfig,
RagFileTransformationConfig,
RagFile as GapicRagFile,
RagManagedDbConfig as GapicRagManagedDbConfig,
RagMetadataSchemaDetails as GapicRagMetadataSchemaDetails,
RagMetadata as GapicRagMetadata,
RagVectorDbConfig as GapicRagVectorDbConfig,
RetrieveContextsResponse,
SharePointSources as GapicSharePointSources,
SlackSource as GapicSlackSource,
UserSpecifiedMetadata as GapicUserSpecifiedMetadata,
VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
Expand All @@ -54,15 +59,19 @@
LlmParserConfig,
LlmRanker,
MemoryCorpus,
MetadataValue,
Pinecone,
RagCorpus,
RagCorpusTypeConfig,
RagDataSchema,
RagEmbeddingModelConfig,
RagEngineConfig,
RagFile,
RagManagedDb,
RagManagedDbConfig,
RagManagedVertexVectorSearch,
RagMetadata,
RagMetadataSchemaDetails,
RagResource,
RagRetrievalConfig,
RagVectorDbConfig,
Expand All @@ -76,6 +85,7 @@
SlackChannelsSource,
Spanner,
Unprovisioned,
UserSpecifiedMetadata,
VertexAiSearchConfig,
VertexFeatureStore,
VertexPredictionEndpoint,
Expand Down Expand Up @@ -1146,3 +1156,54 @@
filter=Filter(vector_distance_threshold=0.5),
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
)

# RagMetadata and RagDataSchema
TEST_RAG_DATA_SCHEMA_ID = "test-data-schema-id"
TEST_RAG_DATA_SCHEMA_RESOURCE_NAME = (
f"{TEST_RAG_CORPUS_RESOURCE_NAME}/ragDataSchemas/{TEST_RAG_DATA_SCHEMA_ID}"
)
TEST_RAG_METADATA_ID = "test-metadata-id"
TEST_RAG_METADATA_RESOURCE_NAME = (
f"{TEST_RAG_FILE_RESOURCE_NAME}/ragMetadata/{TEST_RAG_METADATA_ID}"
)

TEST_GAPIC_RAG_DATA_SCHEMA = GapicRagDataSchema(
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
key="key1",
schema_details=GapicRagMetadataSchemaDetails(
type=GapicRagMetadataSchemaDetails.DataType.STRING,
search_strategy=GapicRagMetadataSchemaDetails.SearchStrategy(
search_strategy_type=GapicRagMetadataSchemaDetails.SearchStrategy.SearchStrategyType.EXACT_SEARCH
),
granularity=GapicRagMetadataSchemaDetails.Granularity.GRANULARITY_FILE_LEVEL,
),
)

TEST_RAG_DATA_SCHEMA = RagDataSchema(
name=TEST_RAG_DATA_SCHEMA_RESOURCE_NAME,
key="key1",
schema_details=RagMetadataSchemaDetails(
type="STRING",
search_strategy=RagMetadataSchemaDetails.SearchStrategy(
search_strategy_type="EXACT_SEARCH"
),
granularity="GRANULARITY_FILE_LEVEL",
),
)

TEST_GAPIC_RAG_METADATA = GapicRagMetadata(
name=TEST_RAG_METADATA_RESOURCE_NAME,
user_specified_metadata=GapicUserSpecifiedMetadata(
key="key1",
value=GapicMetadataValue(str_value="value1"),
),
)

TEST_RAG_METADATA = RagMetadata(
name=TEST_RAG_METADATA_RESOURCE_NAME,
user_specified_metadata=UserSpecifiedMetadata(
values={
"key1": MetadataValue(string_value="value1"),
}
),
)
Loading
Loading