diff --git a/src/opengradient/client/model_hub.py b/src/opengradient/client/model_hub.py index d8f5c92..f93f57e 100644 --- a/src/opengradient/client/model_hub.py +++ b/src/opengradient/client/model_hub.py @@ -1,6 +1,7 @@ """Model Hub for creating, versioning, and uploading ML models.""" import os +import time from typing import Dict, List, Optional import firebase # type: ignore[import-untyped] @@ -19,6 +20,9 @@ "databaseURL": os.getenv("FIREBASE_DATABASE_URL", ""), } +# Firebase idTokens expire after 3600 seconds; refresh 60 s before expiry +_TOKEN_REFRESH_MARGIN_SEC = 60 + class ModelHub: """ @@ -34,7 +38,14 @@ class ModelHub: """ def __init__(self, email: Optional[str] = None, password: Optional[str] = None): - self._hub_user = self._login(email, password) if email is not None else None + self._firebase_app = None + self._hub_user = None + self._token_expiry: float = 0.0 + + if email is not None: + self._firebase_app, self._hub_user = self._login(email, password) + expires_in = int(self._hub_user.get("expiresIn", 3600)) + self._token_expiry = time.time() + expires_in @staticmethod def _login(email: str, password: Optional[str]): @@ -42,7 +53,34 @@ def _login(email: str, password: Optional[str]): raise ValueError("Firebase API Key is missing in environment variables") firebase_app = firebase.initialize_app(_FIREBASE_CONFIG) - return firebase_app.auth().sign_in_with_email_and_password(email, password) + user = firebase_app.auth().sign_in_with_email_and_password(email, password) + return firebase_app, user + + def _get_auth_token(self) -> str: + """Return a valid Firebase idToken, refreshing it if it has expired or is + about to expire within ``_TOKEN_REFRESH_MARGIN_SEC`` seconds. + + Raises: + ValueError: If the user is not authenticated. + """ + if not self._hub_user: + raise ValueError("User not authenticated") + + if time.time() >= self._token_expiry - _TOKEN_REFRESH_MARGIN_SEC: + # Refresh the token using the stored refresh token + refresh_token = self._hub_user.get("refreshToken") + if not refresh_token or self._firebase_app is None: + raise ValueError( + "Cannot refresh Firebase token: missing refresh token or Firebase app. " + "Please re-authenticate by creating a new ModelHub instance." + ) + refreshed = self._firebase_app.auth().refresh(refresh_token) + self._hub_user["idToken"] = refreshed["idToken"] + self._hub_user["refreshToken"] = refreshed.get("refreshToken", refresh_token) + expires_in = int(refreshed.get("expiresIn", 3600)) + self._token_expiry = time.time() + expires_in + + return str(self._hub_user["idToken"]) # cast Any->str for mypy [no-any-return] def create_model(self, model_name: str, model_desc: str, version: str = "1.00") -> ModelRepository: """ @@ -51,19 +89,17 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00") Args: model_name (str): The name of the model. model_desc (str): The description of the model. - version (str): The version identifier (default is "1.00"). + version (str): A label used in the initial version notes (default is "1.00"). + Note: the actual version string is assigned by the server. Returns: - dict: The server response containing model details. + ModelRepository: Object containing the model name and server-assigned version string. Raises: - CreateModelError: If the model creation fails. + RuntimeError: If the model creation fails. """ - if not self._hub_user: - raise ValueError("User not authenticated") - url = "https://api.opengradient.ai/api/v0/models/" - headers = {"Authorization": f"Bearer {self._hub_user['idToken']}", "Content-Type": "application/json"} + headers = {"Authorization": f"Bearer {self._get_auth_token()}", "Content-Type": "application/json"} payload = {"name": model_name, "description": model_desc} try: @@ -74,14 +110,18 @@ def create_model(self, model_name: str, model_desc: str, version: str = "1.00") raise RuntimeError(f"Model creation failed: {error_details}") from e json_response = response.json() - model_name = json_response.get("name") - if not model_name: + created_name = json_response.get("name") + if not created_name: raise Exception(f"Model creation response missing 'name'. Full response: {json_response}") - # Create the specified version for the newly created model - version_response = self.create_version(model_name, version) + # Create the initial version for the newly created model. + # Pass `version` as release notes (e.g. "1.00") since the server assigns + # its own version string — previously `version` was incorrectly passed as + # the positional `notes` argument, resulting in raw version labels as notes + # rather than the clearer "Initial version