Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ conda create -n samapi -y python=3.11
conda activate samapi
```

Install `cudatoolkit`.
If you're using a computer with CUDA-compatible GPU, install `cudatoolkit`.

```bash
conda install -y cudatoolkit=11.3
```

Install `samapi` and its dependencies.
If you are using WSL2, `LD_LIBRARY_PATH` will need to be updated as follows.

```bash
python -m pip install git+https://github.com/ksugar/samapi.git
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH
```

If you are using WSL2, `LD_LIBRARY_PATH` will need to be updated as follows.
Install `samapi` and its dependencies.

```bash
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH
python -m pip install git+https://github.com/ksugar/samapi.git
```

## Usage
Expand Down
22 changes: 18 additions & 4 deletions src/samapi/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from enum import Enum
from typing import Optional, Tuple

Expand All @@ -8,6 +9,7 @@
from pydantic import Field
from segment_anything import sam_model_registry, SamPredictor
from torch.hub import load_state_dict_from_url
import torch

from samapi.utils import decode_image, mask_to_geometry

Expand All @@ -32,14 +34,26 @@ class ModelType(str, Enum):
),
}

def _get_device() -> str:
"""
Selects the device to use for inference, based on what is available.
:return:
"""
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_built() and torch.backends.mps.is_built():
return "mps"
else:
warnings.warn("No GPU support found - using CPU for inference")
return "cpu"

def get_sam_model(model_type: ModelType):
sam = sam_model_registry[model_type]()
sam.load_state_dict(SAM_CHECKPOINTS[model_type])
return sam


predictor = SamPredictor(get_sam_model(ModelType.vit_h).to(device="cuda"))
device = _get_device()
predictor = SamPredictor(get_sam_model(ModelType.vit_h).to(device=device))
sam_type = ModelType.vit_h


Expand All @@ -54,7 +68,7 @@ async def predict_sam(body: SAMBody):
global sam_type
global predictor
if body.type != sam_type:
predictor = SamPredictor(get_sam_model(body.type).to(device="cuda"))
predictor = SamPredictor(get_sam_model(body.type).to(device=device))
sam_type = body.type
image = decode_image(body.b64img)
if image.ndim == 2:
Expand All @@ -75,4 +89,4 @@ async def predict_sam(body: SAMBody):
properties={"object_idx": index_number, "label": "object"},
)
)
return features
return features