Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I modified part of the code to enable parallel (batched) inference with num_batch > 1 #1113

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def perform_inference(self, image: np.ndarray):

if self.image_size is not None:
kwargs = {"imgsz": self.image_size, **kwargs}
if type(image) is list:

prediction_result = self.model(image[:, :, ::-1], **kwargs) # YOLOv8 expects numpy arrays to have BGR

prediction_result = self.model(image, **kwargs) # YOLOv8 expects numpy arrays to have BGR
else :
prediction_result = self.model(image[:, :, ::-1], **kwargs)
if self.has_mask:
if not prediction_result[0].masks:
prediction_result[0].masks = Masks(
Expand Down Expand Up @@ -109,7 +111,10 @@ def perform_inference(self, image: np.ndarray):
prediction_result = [result.boxes.data for result in prediction_result]

self._original_predictions = prediction_result
self._original_shape = image.shape
if type(image) == list:
self._original_shape = image[0].shape
else:
self._original_shape = image.shape

@property
def category_names(self):
Expand Down
51 changes: 33 additions & 18 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import time
import math
from typing import List, Optional

from sahi.utils.import_utils import is_available
Expand Down Expand Up @@ -87,10 +88,11 @@ def get_prediction(
durations_in_seconds = dict()

# read image as pil
image_as_pil = read_image_as_pil(image)
# image_as_pil = read_image_as_pil(image)
# get prediction
time_start = time.time()
detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
# detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
detection_model.perform_inference(image)
time_end = time.time() - time_start
durations_in_seconds["prediction"] = time_end

Expand All @@ -101,12 +103,10 @@ def get_prediction(
shift_amount=shift_amount,
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list

object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list_per_image
# postprocess matching predictions
if postprocess is not None:
object_prediction_list = postprocess(object_prediction_list)

time_end = time.time() - time_start
durations_in_seconds["postprocess"] = time_end

Expand Down Expand Up @@ -139,6 +139,7 @@ def get_sliced_prediction(
auto_slice_resolution: bool = True,
slice_export_prefix: str = None,
slice_dir: str = None,
num_batch: int = 1
) -> PredictionResult:
"""
Function for slicing the image, getting a prediction for each slice, and combining the predictions into a full-image result.
Expand Down Expand Up @@ -198,8 +199,8 @@ def get_sliced_prediction(
# for profiling
durations_in_seconds = dict()

# currently only 1 batch supported
num_batch = 1
# # currently only 1 batch supported
# num_batch = 1
# create slices from full image
time_start = time.time()
slice_image_result = slice_image(
Expand Down Expand Up @@ -233,7 +234,8 @@ def get_sliced_prediction(
)

# create prediction input
num_group = int(num_slices / num_batch)
# num_group = int(num_slices / num_batch)
num_group = math.ceil(num_slices / num_batch)
if verbose == 1 or verbose == 2:
tqdm.write(f"Performing prediction on {num_slices} slices.")
object_prediction_list = []
Expand All @@ -243,22 +245,31 @@ def get_sliced_prediction(
image_list = []
shift_amount_list = []
for image_ind in range(num_batch):
image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
if (group_ind * num_batch + image_ind) >= num_slices:
break
# image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
img_slice = slice_image_result.images[group_ind * num_batch + image_ind]
img_slice = img_slice[:,:,::-1]
image_list.append(img_slice)
shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
# perform batch prediction
num_full = len(image_list)
prediction_result = get_prediction(
image=image_list[0],
image=image_list,
detection_model=detection_model,
shift_amount=shift_amount_list[0],
full_shape=[
shift_amount=shift_amount_list,
full_shape=[[
slice_image_result.original_image_height,
slice_image_result.original_image_width,
],
]] * num_full,
)

# convert sliced predictions to full predictions
for object_prediction in prediction_result.object_prediction_list:
if object_prediction: # if not empty
object_prediction_list.append(object_prediction.get_shifted_object_prediction())
for object_prediction_per in prediction_result.object_prediction_list:

if len(object_prediction_per) != 0: # if not empty
for object_prediction in object_prediction_per:
object_prediction_list.append(object_prediction.get_shifted_object_prediction())

# merge matching predictions during sliced prediction
if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
Expand All @@ -267,7 +278,7 @@ def get_sliced_prediction(
# perform standard prediction
if num_slices > 1 and perform_standard_pred:
prediction_result = get_prediction(
image=image,
image=[np.array(image)],
detection_model=detection_model,
shift_amount=[0, 0],
full_shape=[
Expand All @@ -276,7 +287,9 @@ def get_sliced_prediction(
],
postprocess=None,
)
object_prediction_list.extend(prediction_result.object_prediction_list)
if len(prediction_result.object_prediction_list) != 0:
for _predicion_result in prediction_result.object_prediction_list:
object_prediction_list.extend(_predicion_result)

# merge matching predictions
if len(object_prediction_list) > 1:
Expand Down Expand Up @@ -377,6 +390,7 @@ def predict(
verbose: int = 1,
return_dict: bool = False,
force_postprocess_type: bool = False,
num_batch: int = 1,
**kwargs,
):
"""
Expand Down Expand Up @@ -569,6 +583,7 @@ def predict(
postprocess_match_threshold=postprocess_match_threshold,
postprocess_class_agnostic=postprocess_class_agnostic,
verbose=1 if verbose else 0,
num_batch = num_batch,
)
object_prediction_list = prediction_result.object_prediction_list
durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
Expand Down
9 changes: 7 additions & 2 deletions sahi/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,13 @@ def __init__(
image: Union[Image.Image, str, np.ndarray],
durations_in_seconds: Optional[Dict] = None,
):
self.image: Image.Image = read_image_as_pil(image)
self.image_width, self.image_height = self.image.size

if type(image) is list:
self.image = image
self.image_width, self.image_height = self.image[0].shape[:2]
else :
self.image: Image.Image = read_image_as_pil(image)
self.image_width, self.image_height = self.image.size
self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
self.durations_in_seconds = durations_in_seconds

Expand Down
Loading