Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude classes from inference using pretrained or custom models #1104

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@

logger = logging.getLogger(__name__)

def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id):
return [
obj_pred for obj_pred in object_prediction_list
if obj_pred.category.name not in (exclude_classes_by_name or [])
and obj_pred.category.id not in (exclude_classes_by_id or [])
]


def get_prediction(
image,
Expand All @@ -61,6 +68,8 @@ def get_prediction(
full_shape=None,
postprocess: Optional[PostprocessPredictions] = None,
verbose: int = 0,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
"""
Function for performing prediction for given image using given detection_model.
Expand Down Expand Up @@ -102,6 +111,11 @@ def get_prediction(
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
object_prediction_list = filter_predictions(
object_prediction_list,
exclude_classes_by_name,
exclude_classes_by_id
)

# postprocess matching predictions
if postprocess is not None:
Expand Down Expand Up @@ -139,6 +153,8 @@ def get_sliced_prediction(
auto_slice_resolution: bool = True,
slice_export_prefix: str = None,
slice_dir: str = None,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
) -> PredictionResult:
"""
Function for slice image + get predicion for each slice + combine predictions in full image.
Expand Down Expand Up @@ -254,6 +270,8 @@ def get_sliced_prediction(
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
Expand All @@ -275,6 +293,8 @@ def get_sliced_prediction(
slice_image_result.original_image_width,
],
postprocess=None,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list.extend(prediction_result.object_prediction_list)

Expand Down Expand Up @@ -377,6 +397,8 @@ def predict(
verbose: int = 1,
return_dict: bool = False,
force_postprocess_type: bool = False,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -569,6 +591,8 @@ def predict(
postprocess_match_threshold=postprocess_match_threshold,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=1 if verbose else 0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list = prediction_result.object_prediction_list
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
Expand All @@ -581,6 +605,8 @@ def predict(
full_shape=None,
postprocess=None,
verbose=0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
object_prediction_list = prediction_result.object_prediction_list

Expand Down Expand Up @@ -745,6 +771,8 @@ def predict_fiftyone(
postprocess_match_threshold: float = 0.5,
postprocess_class_agnostic: bool = False,
verbose: int = 1,
exclude_classes_by_name: Optional[List[str]] = None,
exclude_classes_by_id: Optional[List[int]] = None,
):
"""
Performs prediction for all present images in given folder.
Expand Down Expand Up @@ -855,6 +883,8 @@ def predict_fiftyone(
postprocess_match_metric=postprocess_match_metric,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=verbose,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
else:
Expand All @@ -866,6 +896,8 @@ def predict_fiftyone(
full_shape=None,
postprocess=None,
verbose=0,
exclude_classes_by_name=exclude_classes_by_name,
exclude_classes_by_id=exclude_classes_by_id,
)
durations_in_seconds["prediction"] += prediction_result.durations_in_seconds["prediction"]

Expand Down Expand Up @@ -912,4 +944,4 @@ def predict_fiftyone(
# Show samples with most false positives
session.view = eval_view.sort_by("eval_fp", reverse=True)
while 1:
time.sleep(3)
time.sleep(3)
87 changes: 87 additions & 0 deletions tests/test_exclude_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from sahi.utils.file import download_from_url
from sahi.utils.yolov8 import download_yolov8s_model
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction, predict

# 1. Download the YOLOv8 model weights
yolov8_model_path = "models/yolov8s.pt"
download_yolov8s_model(yolov8_model_path)

# 2. Download sample test images
download_from_url(
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg",
"demo_data/small-vehicles1.jpeg",
)
download_from_url(
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png",
"demo_data/terrain2.png",
)

# 3. Load the YOLOv8 detection model
detection_model = AutoDetectionModel.from_pretrained(
model_type="yolov8", # Model type (YOLOv8 in this case)
model_path=yolov8_model_path, # Path to model weights
confidence_threshold=0.5, # Confidence threshold for predictions
device="cpu", # Use "cuda" for GPU inference
)

# 4. Define the classes to exclude
exclude_classes_by_name = ["car"]

# 5. Demonstrate `get_prediction` with class exclusion
print("===== Testing `get_prediction` =====")
result = get_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
shift_amount=[0, 0], # No shift applied
full_shape=None, # Full image shape is not provided
postprocess=None, # Postprocess disabled
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
)

print("\nFiltered Results from `get_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

# 6. Demonstrate `get_sliced_prediction` with and without filtering
print("\n===== Testing `get_sliced_prediction` (Without Filtering) =====")
result = get_sliced_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
slice_height=256, # Slice height
slice_width=256, # Slice width
overlap_height_ratio=0.2, # Overlap height ratio
overlap_width_ratio=0.2, # Overlap width ratio
verbose=1, # Enable verbose output
)
print("\nNon-Filtered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

print("\n===== Testing `get_sliced_prediction` (With Filtering) =====")
result = get_sliced_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
verbose=1,
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
)
print("\nFiltered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
print(f"Class ID: {obj.category.id}, Class Name: {obj.category.name}, Score: {obj.score}")

# 7. Demonstrate `predict` with filtering for a single image
print("\n===== Testing `predict` =====")
predict(
detection_model=detection_model,
source="demo_data/small-vehicles1.jpeg", # Single image source
project="runs/test_predict", # Output project directory
name="exclude_test", # Run name
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
)
print("\nFiltered results from `predict` saved in 'runs/test_predict/exclude_test'")
Loading