Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion adk/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def create_exception(exception, loading_exception=False):
response = json.dumps({
"error": {
"message": str(exception),
"stacktrace": traceback.format_exc(),
"stacktrace": " ".join(traceback.format_exception(etype=type(exception), value=exception, tb=exception.__traceback__)),
"error_type": error_type,
}
})
Expand Down
10 changes: 10 additions & 0 deletions adk/mlops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ def __init__(self, api_token, path):
if not os.path.exists(self.agent_dir):
raise Exception("environment is not configured for mlops.\nPlease select a valid mlops enabled environment.")

if self.endpoint is None:
raise Exception("'no endpoint found, please add 'MLOPS_SERVICE_URL' environment variable, or create an "
"mlops.json file")
if self.model_id is None:
raise Exception("no model_id found, please add 'MODEL_ID' environment variable, or create an mlops.json "
"file")
if self.deployment_id is None:
raise Exception("no deployment_id found, please add 'DEPLOYMENT_ID' environment variable, or create an "
"mlops.json file")

def init(self):
os.environ['MLOPS_DEPLOYMENT_ID'] = self.deployment_id
os.environ['MLOPS_MODEL_ID'] = self.model_id
Expand Down
55 changes: 30 additions & 25 deletions adk/modeldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,42 @@ def initialize(self):
self.models[name] = FileData(real_hash, local_data_path)

def get_model(self, model_name):
if model_name in self.models:
return self.models[model_name].file_path
elif len([optional for optional in self.manifest_data['optional_files'] if
optional['name'] == model_name]) > 0:
self.find_optional_model(model_name)
return self.models[model_name].file_path
if self.available():
if model_name in self.models:
return self.models[model_name].file_path
elif len([optional for optional in self.manifest_data['optional_files'] if
optional['name'] == model_name]) > 0:
self.find_optional_model(model_name)
return self.models[model_name].file_path
else:
raise Exception("model name " + model_name + " not found in manifest")
else:
raise Exception("model name " + model_name + " not found in manifest")
raise Exception("unable to get model {}, model_manifest.json not found.".format(model_name))

def find_optional_model(self, file_name):

found_models = [optional for optional in self.manifest_data['optional_files'] if
optional['name'] == file_name]
if len(found_models) == 0:
raise Exception("file with name '" + file_name + "' not found in model manifest.")
model_info = found_models[0]
self.models[file_name] = {}
source_uri = model_info['source_uri']
fail_on_tamper = model_info.get("fail_on_tamper", False)
expected_hash = model_info.get('md5_checksum', None)
with self.client.file(source_uri).getFile() as f:
local_data_path = f.name
real_hash = md5_for_file(local_data_path)
if self.using_frozen:
if real_hash != expected_hash and fail_on_tamper:
raise Exception("Model File Mismatch for " + file_name +
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
if self.available():
found_models = [optional for optional in self.manifest_data['optional_files'] if
optional['name'] == file_name]
if len(found_models) == 0:
raise Exception("file with name '" + file_name + "' not found in model manifest.")
model_info = found_models[0]
self.models[file_name] = {}
source_uri = model_info['source_uri']
fail_on_tamper = model_info.get("fail_on_tamper", False)
expected_hash = model_info.get('md5_checksum', None)
with self.client.file(source_uri).getFile() as f:
local_data_path = f.name
real_hash = md5_for_file(local_data_path)
if self.using_frozen:
if real_hash != expected_hash and fail_on_tamper:
raise Exception("Model File Mismatch for " + file_name +
"\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash)
else:
self.models[file_name] = FileData(real_hash, local_data_path)
else:
self.models[file_name] = FileData(real_hash, local_data_path)
else:
self.models[file_name] = FileData(real_hash, local_data_path)
raise Exception("unable to get model {}, model_manifest.json not found.".format(model_name))

def get_manifest(self):
if os.path.exists(self.manifest_frozen_path):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_adk_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,18 @@ def test_manifest_file_success(self):
self.assertEqual(expected_output, actual_output)

def test_manifest_file_tampered(self):
input = "Algorithmia"
input = 'Algorithmia'
expected_output = {"error": {"error_type": "LoadingError",
"message": "Model File Mismatch for squeezenet\n"
"expected hash: f20b50b44fdef367a225d41f747a0963\n"
"real hash: 46a44d32d2c5c07f7f66324bef4c7266",
"stacktrace": "NoneType: None\n"}}
"stacktrace": ''}}

actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing,
loading_with_manifest,
manifest_path="tests/manifests/bad_model_manifest"
".json"))
actual_output['error']['stacktrace'] = ''
self.assertEqual(expected_output, actual_output)


Expand Down