Skip to content

Commit

Permalink
update functions docstrings and type hinting (#1016)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Apr 8, 2024
1 parent a40db33 commit efe32cb
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 51 deletions.
2 changes: 1 addition & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.16"
__version__ = "0.11.17"

from sahi.annotation import BoundingBox, Category, Mask
from sahi.auto_model import AutoDetectionModel
Expand Down
190 changes: 145 additions & 45 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


class Colors:
# color palette
def __init__(self):
hex = (
"FF3838",
Expand All @@ -45,16 +44,38 @@ def __init__(self):
"FF95C8",
"FF37C7",
)
self.palette = [self.hex2rgb("#" + c) for c in hex]
self.palette = [self.hex_to_rgb("#" + c) for c in hex]
self.n = len(self.palette)

def __call__(self, i, bgr=False):
c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c
def __call__(self, ind, bgr: bool = False):
"""
Convert an index to a color code.
Args:
ind (int): The index to convert.
bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.
Returns:
tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
"""
color_codes = self.palette[int(ind) % self.n]
return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes

@staticmethod
def hex2rgb(h): # rgb order
return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
def hex_to_rgb(hex_code):
"""
Converts a hexadecimal color code to RGB format.
Args:
hex_code (str): The hexadecimal color code to convert.
Returns:
tuple: A tuple representing the RGB values in the order (R, G, B).
"""
rgb = []
for i in (0, 2, 4):
rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
return tuple(rgb)


def crop_object_predictions(
Expand All @@ -65,23 +86,25 @@ def crop_object_predictions(
export_format: str = "png",
):
"""
Crops bounding boxes over the source image and exports it to output folder.
Arguments:
object_predictions: a list of prediction.ObjectPrediction
output_dir: directory for resulting visualization to be exported
file_name: exported file will be saved as: output_dir+file_name+".png"
export_format: can be specified as 'jpg' or 'png'
Crops bounding boxes over the source image and exports it to the output folder.
Args:
image (np.ndarray): The source image to crop bounding boxes from.
object_prediction_list: A list of object predictions.
output_dir (str): The directory where the resulting visualizations will be exported. Defaults to an empty string.
file_name (str): The name of the exported file. The exported file will be saved as `output_dir + file_name + ".png"`. Defaults to "prediction_visual".
export_format (str): The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".
"""
# create output folder if not present
Path(output_dir).mkdir(parents=True, exist_ok=True)
# add bbox and mask to image if present
for ind, object_prediction in enumerate(object_prediction_list):
# deepcopy object_prediction_list so that original is not altered
# deepcopy object_prediction_list so that the original is not altered
object_prediction = object_prediction.deepcopy()
bbox = object_prediction.bbox.to_xyxy()
category_id = object_prediction.category.id
# crop detections
# deepcopy crops so that original is not altered
# deepcopy crops so that the original is not altered
cropped_img = copy.deepcopy(
image[
int(bbox[1]) : int(bbox[3]),
Expand All @@ -98,7 +121,12 @@ def crop_object_predictions(

def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False):
"""
Reads image from path and saves as given extension.
Reads an image from the given path and saves it with the specified extension.
Args:
read_path (str): The path to the image file.
extension (str, optional): The desired file extension for the saved image. Defaults to "jpg".
grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to False.
"""
image = cv2.imread(read_path)
pre, ext = os.path.splitext(read_path)
Expand All @@ -110,6 +138,17 @@ def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False)


def read_large_image(image_path: str):
"""
Reads a large image from the specified image path.
Args:
image_path (str): The path to the image file.
Returns:
tuple: A tuple containing the image data and a flag indicating whether cv2 was used to read the image.
The image data is a numpy array representing the image in RGB format.
The flag is True if cv2 was used, False otherwise.
"""
use_cv2 = True
# read image, cv2 fails on large files
try:
Expand All @@ -130,7 +169,13 @@ def read_large_image(image_path: str):

def read_image(image_path: str):
"""
Loads image as numpy array from given path.
Loads image as a numpy array from the given path.
Args:
image_path (str): The path to the image file.
Returns:
numpy.ndarray: The loaded image as a numpy array.
"""
# read image
image = cv2.imread(image_path)
Expand All @@ -144,7 +189,12 @@ def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool
Loads an image as PIL.Image.Image.
Args:
image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image
image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
a numpy image (np.ndarray), or a PIL.Image object.
exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to False.
Returns:
PIL.Image.Image: The loaded image as a PIL.Image object.
"""
# https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
Image.MAX_IMAGE_PIXELS = None
Expand Down Expand Up @@ -184,7 +234,11 @@ def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool

def select_random_color():
"""
Selects random color.
Selects a random color from a predefined list of colors.
Returns:
list: A list representing the RGB values of the selected color.
"""
colors = [
[0, 255, 0],
Expand All @@ -205,6 +259,13 @@ def select_random_color():
def apply_color_mask(image: np.ndarray, color: tuple):
"""
Applies color mask to given input image.
Args:
image (np.ndarray): The input image to apply the color mask to.
color (tuple): The RGB color tuple to use for the mask.
Returns:
np.ndarray: The resulting image with the applied color mask.
"""
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
Expand Down Expand Up @@ -328,6 +389,22 @@ def visualize_prediction(
"""
Visualizes prediction classes, bounding boxes over the source image
and exports it to output folder.
Args:
image (np.ndarray): The source image.
boxes (List[List]): List of bounding boxes coordinates.
classes (List[str]): List of class labels corresponding to each bounding box.
masks (Optional[List[np.ndarray]], optional): List of masks corresponding to each bounding box. Defaults to None.
rect_th (float, optional): Thickness of the bounding box rectangle. Defaults to None.
text_size (float, optional): Size of the text for class labels. Defaults to None.
text_th (float, optional): Thickness of the text for class labels. Defaults to None.
color (tuple, optional): Color of the bounding box and text. Defaults to None.
hide_labels (bool, optional): Whether to hide the class labels. Defaults to False.
output_dir (Optional[str], optional): Output directory to save the visualization. Defaults to None.
file_name (Optional[str], optional): File name for the saved visualization. Defaults to "prediction_visual".
Returns:
dict: A dictionary containing the visualized image and the elapsed time for the visualization process.
"""
elapsed_time = time.time()
# deepcopy image so that original is not altered
Expand All @@ -354,37 +431,39 @@ def visualize_prediction(
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

# add bboxes to image if present
for i in range(len(boxes)):
for box_indice in range(len(boxes)):
# deepcopy boxso that original is not altered
box = copy.deepcopy(boxes[i])
class_ = classes[i]
box = copy.deepcopy(boxes[box_indice])
class_ = classes[box_indice]

# set color
if colors is not None:
color = colors(class_)
# set bbox points
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
point1, point2 = [int(box[0]), int(box[1])], [int(box[2]), int(box[3])]
# visualize boxes
cv2.rectangle(
image,
p1,
p2,
point1,
point2,
color=color,
thickness=rect_th,
)

if not hide_labels:
# arange bounding box text location
label = f"{class_}"
w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
0
] # label width, height
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
# add bounding box text
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA) # filled
cv2.putText(
image,
label,
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
0,
text_size,
(255, 255, 255),
Expand Down Expand Up @@ -417,7 +496,8 @@ def visualize_object_predictions(
"""
Visualizes prediction category names, bounding boxes over the source image
and exports it to output folder.
Arguments:
Args:
object_prediction_list: a list of prediction.ObjectPrediction
rect_th: rectangle thickness
text_size: size of the category name over box
Expand Down Expand Up @@ -472,12 +552,12 @@ def visualize_object_predictions(
if colors is not None:
color = colors(object_prediction.category.id)
# set bbox points
p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
# visualize boxes
cv2.rectangle(
image,
p1,
p2,
point1,
point2,
color=color,
thickness=rect_th,
)
Expand All @@ -489,15 +569,17 @@ def visualize_object_predictions(
if not hide_conf:
label += f" {score:.2f}"

w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height
outside = p1[1] - h - 3 >= 0 # label fits outside box
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
0
] # label width, height
outside = point1[1] - box_height - 3 >= 0 # label fits outside box
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
# add bounding box text
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA) # filled
cv2.putText(
image,
label,
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
(point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
0,
text_size,
(255, 255, 255),
Expand Down Expand Up @@ -541,9 +623,17 @@ def get_coco_segmentation_from_bool_mask(bool_mask):
return coco_segmentation


def get_bool_mask_from_coco_segmentation(coco_segmentation, width, height):
def get_bool_mask_from_coco_segmentation(coco_segmentation: List[List[float]], width: int, height: int) -> np.ndarray:
"""
Convert coco segmentation to 2D boolean mask of given height and width
Parameters:
- coco_segmentation: list of points representing the coco segmentation
- width: width of the boolean mask
- height: height of the boolean mask
Returns:
- bool_mask: 2D boolean mask of size (height, width)
"""
size = [height, width]
points = [np.array(point).reshape(-1, 2).round().astype(int) for point in coco_segmentation]
Expand All @@ -553,9 +643,15 @@ def get_bool_mask_from_coco_segmentation(coco_segmentation, width, height):
return bool_mask


def get_bbox_from_bool_mask(bool_mask):
def get_bbox_from_bool_mask(bool_mask: np.ndarray) -> Optional[List[int]]:
"""
Generate voc bbox ([xmin, ymin, xmax, ymax]) from given bool_mask (2D np.ndarray)
Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.
Args:
bool_mask (np.ndarray): 2D boolean mask.
Returns:
Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.
"""
rows = np.any(bool_mask, axis=1)
cols = np.any(bool_mask, axis=0)
Expand Down Expand Up @@ -596,12 +692,16 @@ def ipython_display(image: np.ndarray):
IPython.display.display(i)


def exif_transpose(image: Image.Image):
def exif_transpose(image: Image.Image) -> Image.Image:
"""
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
:param image: The image to transpose.
:return: An image.
Args:
image (Image.Image): The image to transpose.
Returns:
Image.Image: The transposed image.
"""
exif = image.getexif()
orientation = exif.get(0x0112, 1) # default 1
Expand Down
Loading

0 comments on commit efe32cb

Please sign in to comment.