Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions python/pyarrow/src/arrow/python/python_to_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,14 @@ class PyConverter : public Converter<PyObject*, PyConversionOptions> {
}
};

// Helper function to unwrap extension scalar to its storage scalar
const Scalar& GetStorageScalar(const Scalar& scalar) {
if (scalar.type->id() == Type::EXTENSION) {
return *checked_cast<const ExtensionScalar&>(scalar).value;
}
return scalar;
}

template <typename T, typename Enable = void>
class PyPrimitiveConverter;

Expand Down Expand Up @@ -663,7 +671,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -684,7 +693,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_ASSIGN_OR_RAISE(
auto converted, PyValue::Convert(this->primitive_type_, this->options_, value));
Expand All @@ -710,7 +720,8 @@ class PyPrimitiveConverter<T, enable_if_t<std::is_same<T, FixedSizeBinaryType>::
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -747,7 +758,8 @@ class PyPrimitiveConverter<
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
ARROW_RETURN_NOT_OK(this->primitive_builder_->AppendScalar(*scalar));
ARROW_RETURN_NOT_OK(
this->primitive_builder_->AppendScalar(GetStorageScalar(*scalar)));
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->primitive_type_, this->options_, value, view_));
Expand Down Expand Up @@ -791,7 +803,7 @@ class PyDictionaryConverter<U, enable_if_has_c_type<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_ASSIGN_OR_RAISE(auto converted,
PyValue::Convert(this->value_type_, this->options_, value));
Expand All @@ -810,7 +822,7 @@ class PyDictionaryConverter<U, enable_if_has_string_view<U>>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->value_builder_->AppendScalar(*scalar, 1);
return this->value_builder_->AppendScalar(GetStorageScalar(*scalar), 1);
} else {
ARROW_RETURN_NOT_OK(
PyValue::Convert(this->value_type_, this->options_, value, view_));
Expand Down Expand Up @@ -983,7 +995,7 @@ class PyStructConverter : public StructConverter<PyConverter, PyConverterTrait>
} else if (arrow::py::is_scalar(value)) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar,
arrow::py::unwrap_scalar(value));
return this->struct_builder_->AppendScalar(*scalar);
return this->struct_builder_->AppendScalar(GetStorageScalar(*scalar));
}
switch (input_kind_) {
case InputKind::DICT:
Expand Down
55 changes: 55 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import contextlib
import datetime
import os
import shutil
import subprocess
Expand Down Expand Up @@ -1486,6 +1487,60 @@ def bytes(self):
pa.scalar(bad)


def test_array_from_extension_scalars():
# One case per C++ converter: FixedSizeBinary, Binary/String
builtin_cases = [
(pa.uuid(), [b"0123456789abcdef"]),
(pa.opaque(pa.binary(), "t", "v"), [b"x", b"y"]),
]
for ext_type, values in builtin_cases:
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
assert result.equals(pa.array(values, type=ext_type))

# One case per C++ converter: Numeric, Timestamp/Duration, Struct
custom_cases = [
(IntegerType(), [100, 200]),
(AnnotatedType(pa.timestamp("us"), "ts"),
[datetime.datetime(2023, 1, 1)]),
(MyStructType(), [{"left": 1, "right": 2}]),
]
for ext_type, values in custom_cases:
with registered_extension_type(ext_type):
scalars = [pa.scalar(v, type=ext_type) for v in values]
result = pa.array(scalars, type=ext_type)
assert result.equals(pa.array(values, type=ext_type))

# Null handling
uuid_type = pa.uuid()
scalars = [pa.scalar(b"0123456789abcdef", type=uuid_type),
pa.scalar(None, type=uuid_type)]
result = pa.array(scalars, type=uuid_type)
assert result[0].is_valid and not result[1].is_valid

# ExtensionScalar.from_storage path
scalars = [
pa.ExtensionScalar.from_storage(uuid_type, b"0123456789abcdef"),
pa.ExtensionScalar.from_storage(uuid_type, None),
]
result = pa.array(scalars, type=uuid_type)
expected = pa.array([b"0123456789abcdef", None], type=uuid_type)
assert result.equals(expected)

# Type inference without explicit type
u = uuid4()
scalars = [pa.scalar(u, type=pa.uuid()), None]
result = pa.array(scalars)
assert result.type == pa.uuid()
assert result[0].as_py() == u
assert not result[1].is_valid

# Mixed extension scalars and raw Python objects
u1, u2 = uuid4(), uuid4()
result = pa.array([pa.scalar(u1, type=pa.uuid()), u2], type=pa.uuid())
assert result.equals(pa.array([u1, u2], type=pa.uuid()))


def test_tensor_type():
tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
assert tensor_type.extension_name == "arrow.fixed_shape_tensor"
Expand Down
Loading