diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py index d85e692100202..78b01da1f4de4 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py +++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llamaindex_embedding.py @@ -125,9 +125,19 @@ def execute(self, context: Context) -> dict[str, Any]: nodes = splitter.get_nodes_from_documents(llama_docs) self.log.info("Split %d documents into %d chunks", len(llama_docs), len(nodes)) - # ``VectorStoreIndex(...)`` populates each node's ``.embedding`` as a - # side effect of building the index; capture the index so the - # variable isn't discarded. + # Pre-embed nodes before building the index. ``VectorStoreIndex`` + # internally calls ``embed_nodes()`` which skips nodes whose + # ``.embedding`` is already set, so there are no duplicate API calls. + # This is necessary because ``VectorStoreIndex._get_node_with_embedding()`` + # attaches embeddings to *copies* of the nodes (via ``model_copy()``), + # never the originals — so relying on the index to populate + # ``node.embedding`` as a side effect always yields ``None``. + texts = [node.get_content() for node in nodes] + embeddings = embed_model.get_text_embedding_batch(texts, show_progress=False) + for node, embedding in zip(nodes, embeddings): + node.embedding = embedding + self.log.info("Embedded %d chunks", len(nodes)) + index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=False) if self.persist_dir: @@ -137,7 +147,7 @@ def execute(self, context: Context) -> dict[str, Any]: # base ``get_nodes_from_documents`` signature is typed as # ``list[BaseNode]`` (which has no ``.text``). Cast so mypy doesn't # flag the ``.text`` access; ``node.embedding`` is populated by - # ``VectorStoreIndex`` for every node above. + # the pre-embed step above. text_nodes = cast("list[TextNode]", nodes) chunks = [ { diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py index 43b44f87c9ff4..c8faed3f6f0ae 100644 --- a/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py +++ b/providers/common/ai/tests/unit/common/ai/operators/test_llamaindex_embedding.py @@ -43,12 +43,13 @@ def _node(text: str = "chunk text", metadata: dict | None = None, vector=None): node.text = text node.metadata = metadata or {} node.embedding = vector + node.get_content.return_value = text return node def _byo_embedding(): """Return a duck-typed ``BaseEmbedding`` stand-in (has the two methods the operator checks).""" - return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding"]) + return MagicMock(name="MyBaseEmbedding", spec=["get_text_embedding", "_get_query_embedding", "get_text_embedding_batch"]) class TestEmbeddingOperatorInit: @@ -71,8 +72,12 @@ class TestEmbeddingOperatorExecute: @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): # `embed_model` as a string -> hook builds OpenAIEmbedding. + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[0.1, 0.2]] + mock_get_embed.return_value = mock_model + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="chunk a", vector=[0.1, 0.2]), + _node(text="chunk a"), ] op = LlamaIndexEmbeddingOperator( @@ -84,6 +89,7 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): result = op.execute(context=MagicMock()) mock_get_embed.assert_called_once() + mock_model.get_text_embedding_batch.assert_called_once() assert result["document_count"] == 1 assert result["chunk_count"] == 1 assert result["chunks"][0]["text"] == "chunk a" @@ -92,6 +98,12 @@ def test_string_embed_model_goes_through_hook(self, mock_get_embed, _li): @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook") def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): # ``embed_conn_id`` overrides ``llm_conn_id`` for the embedding API. + mock_hook_instance = MagicMock() + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[0.0]] + mock_hook_instance.get_embedding_model.return_value = mock_model + mock_hook_cls.return_value = mock_hook_instance + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( @@ -112,6 +124,7 @@ def test_string_embed_model_forwards_embed_conn_id(self, mock_hook_cls, _li): def test_byo_embed_model_bypasses_hook(self, _li): # `embed_model` is a non-string instance -> hook is bypassed. byo = _byo_embedding() + byo.get_text_embedding_batch.return_value = [[0.0]] _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [_node()] op = LlamaIndexEmbeddingOperator( @@ -142,9 +155,13 @@ def test_invalid_embed_model_raises_typeerror(self, _li): @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[1.0, 2.0], [3.0, 4.0]] + mock_get_embed.return_value = mock_model + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ - _node(text="x", metadata={"k": "v"}, vector=[1.0, 2.0]), - _node(text="y", metadata={"k": "v2"}, vector=[3.0, 4.0]), + _node(text="x", metadata={"k": "v"}), + _node(text="y", metadata={"k": "v2"}), ] op = LlamaIndexEmbeddingOperator( @@ -159,6 +176,30 @@ def test_chunks_carry_text_metadata_vector(self, mock_get_embed, _li): {"text": "y", "metadata": {"k": "v2"}, "vector": [3.0, 4.0]}, ] + @patch("airflow.providers.common.ai.hooks.llamaindex.LlamaIndexHook.get_embedding_model") + def test_pre_embed_called_with_node_texts(self, mock_get_embed, _li): + # Verify that get_text_embedding_batch is called with the correct texts + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[0.1], [0.2], [0.3]] + mock_get_embed.return_value = mock_model + + _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [ + _node(text="first chunk"), + _node(text="second chunk"), + _node(text="third chunk"), + ] + + op = LlamaIndexEmbeddingOperator( + task_id="test", + documents=[{"text": "doc"}], + embed_model="text-embedding-3-small", + ) + op.execute(context=MagicMock()) + + # get_text_embedding_batch should be called with the node texts + call_args = mock_model.get_text_embedding_batch.call_args + assert call_args[0][0] == ["first chunk", "second chunk", "third chunk"] + class TestEmbeddingOperatorPersist: @patch("os.makedirs") @@ -166,6 +207,10 @@ class TestEmbeddingOperatorPersist: def test_local_persist_dir_calls_makedirs_and_storage_persist( self, mock_get_embed, mock_makedirs, _li, tmp_path ): + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[0.0]] + mock_get_embed.return_value = mock_model + node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] index = _li["VectorStoreIndex"].return_value @@ -193,6 +238,10 @@ def test_cloud_uri_persist_dir_uses_object_storage_path(self, mock_get_embed, mo target.fs = MagicMock(name="s3fs") mock_osp_cls.return_value = target + mock_model = MagicMock() + mock_model.get_text_embedding_batch.return_value = [[0.0]] + mock_get_embed.return_value = mock_model + node = _node() _li["SentenceSplitter"].return_value.get_nodes_from_documents.return_value = [node] index = _li["VectorStoreIndex"].return_value