Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,19 @@ def execute(self, context: Context) -> dict[str, Any]:
nodes = splitter.get_nodes_from_documents(llama_docs)
self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes))

# Pre-embed nodes before building the index. ``VectorStoreIndex``
# internally calls ``embed_nodes()`` which skips nodes whose
# ``.embedding`` is already set, so there are no duplicate API calls.
# This is necessary because ``VectorStoreIndex._get_node_with_embedding()``
# attaches embeddings to *copies* of the nodes (via ``model_copy()``),
# never the originals — so relying on the index to populate
# ``node.embedding`` as a side effect always yields ``None``.
texts = [node.get_content() for node in nodes]
embeddings = embed_model.get_text_embedding_batch(texts, show_progress=False)
for node, embedding in zip(nodes, embeddings):
node.embedding = embedding
self.log.info("Embedded %d chunks", len(nodes))

index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False)

if self.persist_dir:
Expand All @@ -137,7 +147,7 @@ def execute(self, context: Context) -> dict[str, Any]:
# base ``get_nodes_from_documents`` signature is typed as
# ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't
# flag the ``.text`` access; ``node.embedding`` is populated by
# the pre-embed step above.
text_nodes = cast("list[TextNode]", nodes)
chunks = [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ def _node(text: str = "chunk text", metadata: dict | None = None, vector=None):
node.text = text
node.metadata = metadata or {}
node.embedding = vector
node.get_content.return_value = text
return node


def _byo_embedding():
"""Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks)."""
return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"])
return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding", "get_text_embedding_batch"])


class TestEmbeddingOperatorInit:
Expand All @@ -71,8 +72,12 @@ class TestEmbeddingOperatorExecute:
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
# `embed_model` as a string -> hook builds OpenAIEmbedding.
mock_model = MagicMock()
mock_model.get_text_embedding_batch.return_value = [[0.1, 0.2]]
mock_get_embed.return_value = mock_model

_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
_node(text="chunk a", vector=[0.1, 0.2]),
_node(text="chunk a"),
]

op = LlamaIndexEmbeddingOperator(
Expand All @@ -84,6 +89,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
result = op.execute(context=MagicMock())

mock_get_embed.assert_called_once()
mock_model.get_text_embedding_batch.assert_called_once()
assert result["document_count"] == 1
assert result["chunk_count"] == 1
assert result["chunks"][0]["text"] == "chunk a"
Expand All @@ -92,6 +98,12 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li):
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook")
def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li):
# ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API.
mock_hook_instance = MagicMock()
mock_model = MagicMock()
mock_model.get_text_embedding_batch.return_value = [[0.0]]
mock_hook_instance.get_embedding_model.return_value = mock_model
mock_hook_cls.return_value = mock_hook_instance

_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()]

op = LlamaIndexEmbeddingOperator(
Expand All @@ -112,6 +124,7 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li):
def test_byo_embed_model_bypasses_hook(self, _li):
# `embed_model` is a non-string instance -> hook is bypassed.
byo = _byo_embedding()
byo.get_text_embedding_batch.return_value = [[0.0]]
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()]

op = LlamaIndexEmbeddingOperator(
Expand Down Expand Up @@ -142,9 +155,13 @@ def test_invalid_embed_model_raises_typeerror(self, _li):

@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li):
mock_model = MagicMock()
mock_model.get_text_embedding_batch.return_value = [[1.0, 2.0], [3.0, 4.0]]
mock_get_embed.return_value = mock_model

_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [
_node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]),
_node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]),
_node(text="x", metadata={"k": "v"}),
_node(text="y", metadata={"k": "v2"}),
]

op = LlamaIndexEmbeddingOperator(
Expand All @@ -159,13 +176,41 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li):
{"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]},
]

@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_pre_embed_called_with_node_texts(self, mock_get_embed, _li):
    """The pre-embed step passes every chunk's text, in order, to ``get_text_embedding_batch``."""
    chunk_texts = ["first chunk", "second chunk", "third chunk"]

    # Embedding model returned by the (patched) hook; one vector per chunk.
    embedder = MagicMock()
    embedder.get_text_embedding_batch.return_value = [[0.1], [0.2], [0.3]]
    mock_get_embed.return_value = embedder

    # The splitter yields one un-embedded node per expected text.
    splitter = _li["SentenceSplitter"].return_value
    splitter.get_nodes_from_documents.return_value = [_node(text=t) for t in chunk_texts]

    operator = LlamaIndexEmbeddingOperator(
        task_id="test",
        documents=[{"text": "doc"}],
        embed_model="text-embedding-3-small",
    )
    operator.execute(context=MagicMock())

    # The first positional argument of the batch call is the list of node texts.
    batch_call = embedder.get_text_embedding_batch.call_args
    assert batch_call.args[0] == chunk_texts


class TestEmbeddingOperatorPersist:
@patch("os.makedirs")
@patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model")
def test_local_persist_dir_calls_makedirs_and_storage_persist(
self, mock_get_embed, mock_makedirs, _li, tmp_path
):
mock_model = MagicMock()
mock_model.get_text_embedding_batch.return_value = [[0.0]]
mock_get_embed.return_value = mock_model

node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node]
index = _li["VectorStoreIndex"].return_value
Expand Down Expand Up @@ -193,6 +238,10 @@ def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mo
target.fs = MagicMock(name="s3fs")
mock_osp_cls.return_value = target

mock_model = MagicMock()
mock_model.get_text_embedding_batch.return_value = [[0.0]]
mock_get_embed.return_value = mock_model

node = _node()
_li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node]
index = _li["VectorStoreIndex"].return_value
Expand Down
Loading