fix(Enum): support stdlib Enum conversion#562
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for Python standard library Enum types within the TVM FFI, allowing py_class fields to accept Enum members by automatically converting them to their underlying values. It includes updates to the type converter and schema generation logic, along with deduplication for Union types. Feedback suggests optimizing the schema comparison and deduplication logic in type_info.pxi by using direct TypeSchema equality checks instead of expensive json.dumps serialization.
| converted_list = [] | ||
| seen = set() | ||
| for arg in non_none: | ||
| schema = TypeSchema.from_annotation(arg) | ||
| key = json.dumps(schema.to_json(), sort_keys=True) | ||
| if key in seen: | ||
| continue | ||
| seen.add(key) | ||
| converted_list.append(schema) | ||
| converted = tuple(converted_list) |
There was a problem hiding this comment.
Using json.dumps to deduplicate schemas in a loop is inefficient due to string allocations and JSON processing. Since TypeSchema is a dataclass with a generated __eq__ method that performs structural comparison, you can deduplicate using a simple list membership check. Given that unions typically have very few arguments, this approach is significantly more efficient.
converted_list = []
for arg in non_none:
schema = TypeSchema.from_annotation(arg)
if schema not in converted_list:
converted_list.append(schema)
converted = tuple(converted_list)
| def _annotation_py_enum(cls): | ||
| """Map a homogeneous stdlib Enum subclass to its underlying value schema.""" | ||
| members = tuple(cls) | ||
| if not members: | ||
| raise TypeError(f"Enum subclass {cls!r} has no members") | ||
| first_schema = TypeSchema.from_annotation(type(members[0].value)) | ||
| first_key = json.dumps(first_schema.to_json(), sort_keys=True) | ||
| for member in members[1:]: | ||
| member_schema = TypeSchema.from_annotation(type(member.value)) | ||
| member_key = json.dumps(member_schema.to_json(), sort_keys=True) | ||
| if member_key != first_key: | ||
| raise TypeError( | ||
| f"Enum subclass {cls!r} has mixed value types and cannot be converted to TypeSchema" | ||
| ) | ||
| return first_schema |
There was a problem hiding this comment.
In _annotation_py_enum, comparing schemas using json.dumps is unnecessary and expensive. You can use the TypeSchema equality operator (==) directly, which performs a recursive structural comparison of the origin, args, and origin_type_index fields. This avoids repeated JSON serialization in the loop.
def _annotation_py_enum(cls):
"""Map a homogeneous stdlib Enum subclass to its underlying value schema."""
members = tuple(cls)
if not members:
raise TypeError(f"Enum subclass {cls!r} has no members")
first_schema = TypeSchema.from_annotation(type(members[0].value))
for member in members[1:]:
member_schema = TypeSchema.from_annotation(type(member.value))
if member_schema != first_schema:
raise TypeError(
f"Enum subclass {cls!r} has mixed value types and cannot be converted to TypeSchema"
)
return first_schema
No description provided.