Skip to content

Commit a4ff124

Browse files
Internal change
PiperOrigin-RevId: 471911078
1 parent a4ae812 commit a4ff124

File tree

3 files changed

+63
-115
lines changed

3 files changed

+63
-115
lines changed

official/vision/ops/box_ops.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,3 +783,66 @@ def box_matching(boxes, gt_boxes, gt_classes):
783783

784784
return (matched_gt_boxes, matched_gt_classes, matched_gt_indices,
785785
matched_iou, iou)
786+
787+
788+
def bbox2mask(bbox: tf.Tensor,
              *,
              image_height: int,
              image_width: int,
              dtype: tf.DType = tf.bool) -> tf.Tensor:
  """Converts bounding boxes to bitmasks.

  A pixel at (y, x) is set when `ymin <= y < ymax` and `xmin <= x < xmax`,
  i.e. the max coordinates are exclusive. Because the comparison is done
  against the pixel grid of the image, parts of a box that fall outside the
  image are dropped, and a box with `ymin >= ymax` or `xmin >= xmax` produces
  an all-zero mask.

  Args:
    bbox: A tensor of shape (..., 4) with an arbitrary number of batch
      dimensions, representing the absolute coordinates (ymin, xmin, ymax,
      xmax) of each bounding box.
    image_height: An integer representing the height of the image.
    image_width: An integer representing the width of the image.
    dtype: DType of the output bitmasks.

  Returns:
    A tensor of shape (..., image_height, image_width) which stores the
    bitmasks created from the bounding boxes. For example:

    >>> bbox2mask(tf.constant([[1,2,4,4]]),
                  image_height=5,
                  image_width=5,
                  dtype=tf.int32)
    <tf.Tensor: shape=(1, 5, 5), dtype=int32, numpy=
    array([[[0, 0, 0, 0, 0],
            [0, 0, 1, 1, 0],
            [0, 0, 1, 1, 0],
            [0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0]]], dtype=int32)>

  Raises:
    ValueError: If the static size of the last dimension of `bbox` is not 4.
  """
  bbox_shape = bbox.get_shape().as_list()
  if bbox_shape[-1] != 4:
    raise ValueError(
        'Expected the last dimension of `bbox` has size == 4, but the shape '
        'of `bbox` was: %s' % bbox_shape)

  # Each coordinate becomes (..., 1, 1) so that it broadcasts against the
  # pixel grids below; there is no need to tile it across the image first.
  ymin = tf.expand_dims(bbox[..., 0:1], axis=-1)
  xmin = tf.expand_dims(bbox[..., 1:2], axis=-1)
  ymax = tf.expand_dims(bbox[..., 2:3], axis=-1)
  xmax = tf.expand_dims(bbox[..., 3:4], axis=-1)

  # Pixel index grids use the dtype of `bbox` so the comparisons type-check.
  # (height, 1)
  y_grid = tf.expand_dims(tf.range(image_height, dtype=bbox.dtype), axis=-1)
  # (1, width)
  x_grid = tf.expand_dims(tf.range(image_width, dtype=bbox.dtype), axis=-2)

  # (..., height, 1): rows covered by each box.
  row_mask = tf.logical_and(y_grid >= ymin, y_grid < ymax)
  # (..., 1, width): columns covered by each box.
  col_mask = tf.logical_and(x_grid >= xmin, x_grid < xmax)

  # Broadcasting the row and column masks yields (..., height, width).
  return tf.cast(tf.logical_and(row_mask, col_mask), dtype)

official/vision/ops/mask_ops.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
# Import libraries
1919
import cv2
2020
import numpy as np
21-
import tensorflow as tf
2221

2322

2423
def paste_instance_masks(masks: np.ndarray, detected_boxes: np.ndarray,
@@ -184,66 +183,3 @@ def paste_instance_masks_v2(masks: np.ndarray, detected_boxes: np.ndarray,
184183

185184
segms = np.array(segms)
186185
return segms
187-
188-
189-
def bbox2mask(bbox: tf.Tensor,
190-
*,
191-
image_height: int,
192-
image_width: int,
193-
dtype: tf.DType = tf.bool) -> tf.Tensor:
194-
"""Converts bounding boxes to bitmasks.
195-
196-
Args:
197-
bbox: A tensor in shape (..., 4) with arbitrary numbers of batch dimensions,
198-
representing the absolute coordinates (ymin, xmin, ymax, xmax) for each
199-
bounding box.
200-
image_height: an integer representing the height of the image.
201-
image_width: an integer representing the width of the image.
202-
dtype: DType of the output bitmasks.
203-
204-
Returns:
205-
A tensor in shape (..., height, width) which stores the bitmasks created
206-
from the bounding boxes. For example:
207-
208-
>>> bbox2mask(tf.constant([[1,2,4,4]]),
209-
image_height=5,
210-
image_width=5,
211-
dtype=tf.int32)
212-
<tf.Tensor: shape=(1, 5, 5), dtype=int32, numpy=
213-
array([[[0, 0, 0, 0, 0],
214-
[0, 0, 1, 1, 0],
215-
[0, 0, 1, 1, 0],
216-
[0, 0, 1, 1, 0],
217-
[0, 0, 0, 0, 0]]], dtype=int32)>
218-
"""
219-
bbox_shape = bbox.get_shape().as_list()
220-
if bbox_shape[-1] != 4:
221-
raise ValueError(
222-
'Expected the last dimension of `bbox` has size == 4, but the shape '
223-
'of `bbox` was: %s' % bbox_shape)
224-
225-
# (..., 1)
226-
ymin = bbox[..., 0:1]
227-
xmin = bbox[..., 1:2]
228-
ymax = bbox[..., 2:3]
229-
xmax = bbox[..., 3:4]
230-
# (..., 1, width)
231-
ymin = tf.expand_dims(tf.repeat(ymin, repeats=image_width, axis=-1), axis=-2)
232-
# (..., height, 1)
233-
xmin = tf.expand_dims(tf.repeat(xmin, repeats=image_height, axis=-1), axis=-1)
234-
# (..., 1, width)
235-
ymax = tf.expand_dims(tf.repeat(ymax, repeats=image_width, axis=-1), axis=-2)
236-
# (..., height, 1)
237-
xmax = tf.expand_dims(tf.repeat(xmax, repeats=image_height, axis=-1), axis=-1)
238-
239-
# (height, 1)
240-
y_grid = tf.expand_dims(tf.range(image_height, dtype=bbox.dtype), axis=-1)
241-
# (1, width)
242-
x_grid = tf.expand_dims(tf.range(image_width, dtype=bbox.dtype), axis=-2)
243-
244-
# (..., height, width)
245-
ymin_mask = y_grid >= ymin
246-
xmin_mask = x_grid >= xmin
247-
ymax_mask = y_grid < ymax
248-
xmax_mask = x_grid < xmax
249-
return tf.cast(ymin_mask & xmin_mask & ymax_mask & xmax_mask, dtype)

official/vision/ops/mask_ops_test.py

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -49,57 +49,6 @@ def testPasteInstanceMasksV2(self):
4949
np.array(masks > 0.5, dtype=np.uint8),
5050
1e-5)
5151

52-
def testBbox2mask(self):
53-
bboxes = tf.constant([[1, 2, 4, 4], [-1, -1, 3, 3], [2, 3, 6, 8],
54-
[1, 1, 2, 2], [1, 1, 1, 4]])
55-
masks = mask_ops.bbox2mask(
56-
bboxes, image_height=5, image_width=6, dtype=tf.int32)
57-
expected_masks = tf.constant(
58-
[
59-
[ # bbox = [1, 2, 4, 4]
60-
[0, 0, 0, 0, 0, 0],
61-
[0, 0, 1, 1, 0, 0],
62-
[0, 0, 1, 1, 0, 0],
63-
[0, 0, 1, 1, 0, 0],
64-
[0, 0, 0, 0, 0, 0],
65-
],
66-
[ # bbox = [-1, -1, 3, 3]
67-
[1, 1, 1, 0, 0, 0],
68-
[1, 1, 1, 0, 0, 0],
69-
[1, 1, 1, 0, 0, 0],
70-
[0, 0, 0, 0, 0, 0],
71-
[0, 0, 0, 0, 0, 0],
72-
],
73-
[ # bbox = [2, 3, 6, 8]
74-
[0, 0, 0, 0, 0, 0],
75-
[0, 0, 0, 0, 0, 0],
76-
[0, 0, 0, 1, 1, 1],
77-
[0, 0, 0, 1, 1, 1],
78-
[0, 0, 0, 1, 1, 1],
79-
],
80-
[ # bbox = [1, 1, 2, 2]
81-
[0, 0, 0, 0, 0, 0],
82-
[0, 1, 0, 0, 0, 0],
83-
[0, 0, 0, 0, 0, 0],
84-
[0, 0, 0, 0, 0, 0],
85-
[0, 0, 0, 0, 0, 0],
86-
],
87-
[ # bbox = [1, 1, 1, 4]
88-
[0, 0, 0, 0, 0, 0],
89-
[0, 0, 0, 0, 0, 0],
90-
[0, 0, 0, 0, 0, 0],
91-
[0, 0, 0, 0, 0, 0],
92-
[0, 0, 0, 0, 0, 0],
93-
]
94-
],
95-
dtype=tf.int32)
96-
self.assertAllEqual(expected_masks, masks)
97-
98-
def testBbox2maskInvalidInput(self):
99-
bboxes = tf.constant([[1, 2, 4, 4, 4], [-1, -1, 3, 3, 3]])
100-
with self.assertRaisesRegex(ValueError, 'bbox.*size == 4'):
101-
mask_ops.bbox2mask(bboxes, image_height=5, image_width=6, dtype=tf.int32)
102-
10352

10453
if __name__ == '__main__':
10554
tf.test.main()

0 commit comments

Comments
 (0)