Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude classes from inference using pretrained or custom models #1104

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,24 @@
logger = logging.getLogger(__name__)


def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id):
return [
obj_pred
for obj_pred in object_prediction_list
if obj_pred.category.name not in (exclude_classes_by_name or [])
and obj_pred.category.id not in (exclude_classes_by_id or [])
]


def get_prediction(
image,
detection_model,
shift_amount: list = [0, 0],
full_shape=None,
postprocess: Optional[PostprocessPredictions] = None,
verbose: int = 0,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
"""
Function for performing prediction for given image using given detection_model.
Expand All @@ -78,6 +89,12 @@ def get_prediction(
verbose: int
0: no print (default)
1: print prediction duration
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs

Returns:
A dict with fields:
Expand All @@ -102,6 +119,7 @@ def get_prediction(
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)

# postprocess matching predictions
if postprocess is not None:
Expand Down Expand Up @@ -139,6 +157,8 @@ def get_sliced_prediction(
auto_slice_resolution: bool = True,
slice_export_prefix: str = None,
slice_dir: str = None,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
"""
Function for slice image + get predicion for each slice + combine predictions in full image.
Expand Down Expand Up @@ -188,6 +208,12 @@ def get_sliced_prediction(
Prefix for the exported slices. Defaults to None.
slice_dir: str
Directory to save the slices. Defaults to None.
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs

Returns:
A Dict with fields:
Expand Down Expand Up @@ -254,6 +280,8 @@ def get_sliced_prediction(
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
Expand All @@ -275,6 +303,8 @@ def get_sliced_prediction(
slice_image_result.original_image_width,
],
postprocess=None,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list.extend(prediction_result.object_prediction_list)

Expand Down Expand Up @@ -377,6 +407,8 @@ def predict(
verbose: int = 1,
return_dict: bool = False,
force_postprocess_type: bool = False,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -463,6 +495,12 @@ def predict(
If True, returns a dict with 'export_dir' field.
force_postprocess_type: bool
If True, auto postprocess check will e disabled
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
"""
# assert prediction type
if no_standard_prediction and no_sliced_prediction:
Expand Down Expand Up @@ -569,6 +607,8 @@ def predict(
postprocess_match_threshold=postprocess_match_threshold,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=1 if verbose else 0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list = prediction_result.object_prediction_list
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
Expand All @@ -581,6 +621,8 @@ def predict(
full_shape=None,
postprocess=None,
verbose=0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list = prediction_result.object_prediction_list

Expand Down Expand Up @@ -745,6 +787,8 @@ def predict_fiftyone(
postprocess_match_threshold: float = 0.5,
postprocess_class_agnostic: bool = False,
verbose: int = 1,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
):
"""
Performs prediction for all present images in given folder.
Expand Down Expand Up @@ -803,6 +847,12 @@ def predict_fiftyone(
verbose: int
0: no print
1: print slice/prediction durations, number of slices, model loading/file exporting durations
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
"""
check_requirements(["fiftyone"])

Expand Down Expand Up @@ -855,6 +905,8 @@ def predict_fiftyone(
postprocess_match_metric=postprocess_match_metric,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=verbose,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
else:
Expand All @@ -866,6 +918,8 @@ def predict_fiftyone(
full_shape=None,
postprocess=None,
verbose=0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]

Expand Down
87 changes: 87 additions & 0 deletions tests/test_exclude_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction, predict
from sahi.utils.file import download_from_url
from sahi.utils.yolov8 import download_yolov8s_model

# 1. Download the YOLOv8 model weights
yolov8_model_path = "models/yolov8s.pt"
download_yolov8s_model(yolov8_model_path)

# 2. Download sample test images
download_from_url(
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg",
"demo_data/small-vehicles1.jpeg",
)
download_from_url(
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png",
"demo_data/terrain2.png",
)

# 3. Load the YOLOv8 detection model
detection_model = AutoDetectionModel.from_pretrained(
model_type="yolov8", # Model type (YOLOv8 in this case)
model_path=yolov8_model_path, # Path to model weights
confidence_threshold=0.5, # Confidence threshold for predictions
device="cpu", # Use "cuda" for GPU inference
)

# 4. Define the classes to exclude
exclude_classes_by_name = ["car"]

# 5. Demonstrate `get_prediction` with class exclusion
print("===== Testing `get_prediction` =====")
result = get_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
shift_amount=[0, 0], # No shift applied
full_shape=None, # Full image shape is not provided
postprocess=None, # Postprocess disabled
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)

print("\nFiltered Results from `get_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

# 6. Demonstrate `get_sliced_prediction` with and without filtering
print("\n===== Testing `get_sliced_prediction` (Without Filtering) =====")
result = get_sliced_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
slice_height=256, # Slice height
slice_width=256, # Slice width
overlap_height_ratio=0.2, # Overlap height ratio
overlap_width_ratio=0.2, # Overlap width ratio
verbose=1, # Enable verbose output
)
print("\nNon-Filtered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

print("\n===== Testing `get_sliced_prediction` (With Filtering) =====")
result = get_sliced_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
verbose=1,
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)
print("\nFiltered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

# 7. Demonstrate `predict` with filtering for a single image
print("\n===== Testing `predict` =====")
predict(
detection_model=detection_model,
source="demo_data/small-vehicles1.jpeg", # Single image source
project="runs/test_predict", # Output project directory
name="exclude_test", # Run name
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)
print("\nFiltered results from `predict` saved in 'runs/test_predict/exclude_test'")