-
Notifications
You must be signed in to change notification settings - Fork 614
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Mark Merling <[email protected]> Co-authored-by: fatih c. akyon <[email protected]>
- Loading branch information
1 parent
e37e8e1
commit 4ed7f3d
Showing
5 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
from sahi.utils.cv import ( | ||
Colors, | ||
apply_color_mask, | ||
exif_transpose, | ||
get_bbox_from_bool_mask, | ||
get_coco_segmentation_from_bool_mask, | ||
read_image, | ||
) | ||
|
||
|
||
class TestCvUtils(unittest.TestCase): | ||
def test_hex_to_rgb(self): | ||
colors = Colors() | ||
self.assertEqual(colors.hex_to_rgb("#FF3838"), (255, 56, 56)) | ||
|
||
def test_hex_to_rgb_retrieve(self): | ||
colors = Colors() | ||
self.assertEqual(colors(0), (255, 56, 56)) | ||
|
||
@patch("sahi.utils.cv.cv2.cvtColor") | ||
@patch("sahi.utils.cv.cv2.imread") | ||
def test_read_image(self, mock_imread, mock_cvtColor): | ||
fake_image = "test.jpg" | ||
fake_image_val = np.array([[[10, 20, 30]]], dtype=np.uint8) | ||
fake_image_rbg_val = np.array([[[10, 20, 30]]], dtype=np.uint8) | ||
mock_imread.return_value = fake_image_val | ||
mock_cvtColor.return_value = fake_image_rbg_val | ||
|
||
result = read_image(fake_image) | ||
|
||
# mock_cv2.assert_called_once_with(fake_image) | ||
mock_imread.assert_called_once_with(fake_image) | ||
np.testing.assert_array_equal(result, fake_image_rbg_val) | ||
|
||
def test_apply_color_mask(self): | ||
image = np.array([[0, 1]], dtype=np.uint8) | ||
color = (255, 0, 0) | ||
|
||
expected_output = np.array([[[0, 0, 0], [255, 0, 0]]], dtype=np.uint8) | ||
|
||
result = apply_color_mask(image, color) | ||
|
||
np.testing.assert_array_equal(result, expected_output) | ||
|
||
def test_get_coco_segmentation_from_bool_mask_simple(self): | ||
mask = np.zeros((10, 10), dtype=bool) | ||
result = get_coco_segmentation_from_bool_mask(mask) | ||
self.assertEqual(result, []) | ||
|
||
def test_get_coco_segmentation_from_bool_mask_polygon(self): | ||
mask = np.zeros((10, 20), dtype=bool) | ||
mask[1:4, 1:4] = True | ||
mask[5:8, 5:8] = True | ||
result = get_coco_segmentation_from_bool_mask(mask) | ||
self.assertEqual(len(result), 2) | ||
|
||
def test_get_bbox_from_bool_mask(self): | ||
mask = np.array( | ||
[ | ||
[False, False, False], | ||
[False, True, True], | ||
[False, True, True], | ||
[False, False, False], | ||
] | ||
) | ||
expected_result = [1, 1, 2, 2] | ||
result = get_bbox_from_bool_mask(mask) | ||
self.assertEqual(result, expected_result) | ||
|
||
def test_exif_transpose_simple(self): | ||
test_image = Image.new("RGB", (100, 100), color="red") | ||
transposed = exif_transpose(test_image) | ||
self.assertEqual(transposed, test_image) | ||
|
||
def test_exif_transpose_non_standard(self): | ||
test_image = Image.new("RGB", (100, 100), color="red") | ||
exif = test_image.getexif() | ||
exif[0x0112] = 9 | ||
test_image.info["exif"] = exif.tobytes() | ||
transposed = exif_transpose(test_image) | ||
self.assertEqual(transposed, test_image) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from sahi.postprocess.utils import ObjectPredictionList | ||
|
||
|
||
class TestPostprocessUtils(unittest.TestCase): | ||
def setUp(self): | ||
self.test_input = [ObjectPredictionList([1, 2, 3, 4])] | ||
|
||
def test_get_item_int(self): | ||
obj = self.test_input[0] | ||
self.assertEqual(obj[0].tolist(), 1) | ||
|
||
def test_len(self): | ||
obj = self.test_input[0] | ||
self.assertEqual(len(obj), 4) | ||
|
||
def test_extend(self): | ||
obj = self.test_input[0] | ||
obj.extend(ObjectPredictionList([torch.randn(1, 2, 3, 4)])) | ||
self.assertEqual(len(obj), 5) | ||
|
||
def test_tostring(self): | ||
obj = self.test_input[0] | ||
self.assertEqual(str(obj), str([1, 2, 3, 4])) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from sahi.prediction import PredictionScore | ||
|
||
|
||
class TestPrediction(unittest.TestCase): | ||
def test_prediction_score(self): | ||
|
||
prediction_score = PredictionScore(np.array(0.6)) | ||
self.assertEqual(type(prediction_score.value), float) | ||
self.assertEqual(prediction_score.is_greater_than_threshold(0.5), True) | ||
self.assertEqual(prediction_score.is_greater_than_threshold(0.7), False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from sahi.utils.torch import empty_cuda_cache, to_float_tensor, torch_to_numpy | ||
|
||
|
||
class TestTorchUtils(unittest.TestCase): | ||
def test_empty_cuda_cache(self): | ||
if torch.cuda.is_available(): | ||
self.assertIsNone(empty_cuda_cache()) | ||
|
||
def test_to_float_tensor(self): | ||
|
||
img = to_float_tensor(np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)) | ||
self.assertEqual(img.shape, (3, 10, 10)) | ||
|
||
def test_torch_to_numpy(self): | ||
img_t = torch.tensor(np.random.rand(3, 10, 10)) | ||
img = torch_to_numpy(img_t) | ||
self.assertEqual(img.shape, (10, 10, 3)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |