Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,13 +2298,13 @@ def text(self) -> str:
) from e

@property
def function_calls(self) -> Sequence[gapic_tool_types.FunctionCall]:
def function_calls(self) -> Sequence["FunctionCall"]:
if not self.content or not self.content.parts:
return []
return [
part.function_call
for part in self.content.parts
if part and part.function_call
if part._raw_part._pb.WhichOneof("data") == "function_call"
]


Expand Down Expand Up @@ -2479,8 +2479,8 @@ def file_data(self) -> gapic_content_types.FileData:
return self._raw_part.file_data

@property
def function_call(self) -> gapic_tool_types.FunctionCall:
return self._raw_part.function_call
def function_call(self) -> "FunctionCall":
return FunctionCall._from_gapic(self._raw_part.function_call)

@property
def function_response(self) -> gapic_tool_types.FunctionResponse:
Expand All @@ -2491,6 +2491,35 @@ def _image(self) -> "Image":
return Image.from_bytes(data=self._raw_part.inline_data.data)


class FunctionCall:
"""Function call."""

def __init__(self):
self._raw_message = aiplatform_types.FunctionCall()

@classmethod
def _from_gapic(cls, raw_message: aiplatform_types.FunctionCall) -> "FunctionCall":
response = cls()
response._raw_message = raw_message
return response

def to_dict(self) -> Dict[str, Any]:
return _proto_to_dict(self._raw_message)

def __repr__(self) -> str:
return self._raw_message.__repr__()

@property
def name(self) -> str:
return self._raw_message.name

@property
def args(self) -> Dict[str, Any]:
# We cannot use `type(self.args).to_dict(self.args)`
# due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
return self.to_dict().get("args")


class SafetySetting:
"""Parameters for the generation."""

Expand Down Expand Up @@ -2949,10 +2978,9 @@ def respond_to_model_response(
)

try:
# We cannot use `function_args = type(function_call.args).to_dict(function_call.args)`
# due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
function_args = type(function_call).to_dict(function_call)["args"]
function_call_result = callable_function._function(**function_args)
function_call_result = callable_function._function(
**function_call.args
)
if not isinstance(function_call_result, Mapping):
# If the function returns a single value, wrap it in the
# format that Part.from_function_response can accept.
Expand Down