diff --git a/README.md b/README.md index b27d3ac..e8a61ed 100644 --- a/README.md +++ b/README.md @@ -11,22 +11,22 @@ conda create -n samapi -y python=3.11 conda activate samapi ``` -Install `cudatoolkit`. +If you're using a computer with CUDA-compatible GPU, install `cudatoolkit`. ```bash conda install -y cudatoolkit=11.3 ``` -Install `samapi` and its dependencies. +If you are using WSL2, `LD_LIBRARY_PATH` will need to be updated as follows. ```bash -python -m pip install git+https://github.com/ksugar/samapi.git +export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH ``` -If you are using WSL2, `LD_LIBRARY_PATH` will need to be updated as follows. +Install `samapi` and its dependencies. ```bash -export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH +python -m pip install git+https://github.com/ksugar/samapi.git ``` ## Usage diff --git a/src/samapi/main.py b/src/samapi/main.py index f0e2bb0..168b439 100644 --- a/src/samapi/main.py +++ b/src/samapi/main.py @@ -1,3 +1,4 @@ +import warnings from enum import Enum from typing import Optional, Tuple @@ -8,6 +9,7 @@ from pydantic import Field from segment_anything import sam_model_registry, SamPredictor from torch.hub import load_state_dict_from_url +import torch from samapi.utils import decode_image, mask_to_geometry @@ -32,14 +34,26 @@ class ModelType(str, Enum): ), } +def _get_device() -> str: + """ + Selects the device to use for inference, based on what is available. + :return: + """ + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_built() and torch.backends.mps.is_built(): + return "mps" + else: + warnings.warn("No GPU support found - using CPU for inference") + return "cpu" def get_sam_model(model_type: ModelType): sam = sam_model_registry[model_type]() sam.load_state_dict(SAM_CHECKPOINTS[model_type]) return sam - -predictor = SamPredictor(get_sam_model(ModelType.vit_h).to(device="cuda")) +device = _get_device() +predictor = SamPredictor(get_sam_model(ModelType.vit_h).to(device=device)) sam_type = ModelType.vit_h @@ -54,7 +68,7 @@ async def predict_sam(body: SAMBody): global sam_type global predictor if body.type != sam_type: - predictor = SamPredictor(get_sam_model(body.type).to(device="cuda")) + predictor = SamPredictor(get_sam_model(body.type).to(device=device)) sam_type = body.type image = decode_image(body.b64img) if image.ndim == 2: @@ -75,4 +89,4 @@ async def predict_sam(body: SAMBody): properties={"object_idx": index_number, "label": "object"}, ) ) - return features + return features \ No newline at end of file