|
9 | 9 | Covers: |
10 | 10 | * ``_create_interpreter()`` — interpreter loading with tflite_runtime / tensorflow fallback |
11 | 11 | * ``_run_inference()`` — image preprocessing, invocation, and detection decoding |
| 12 | +* ``_decode_masks()`` — segmentation mask upsampling and thresholding |
12 | 13 | """ |
13 | 14 |
|
14 | 15 | from __future__ import annotations |
|
22 | 23 | import supervision as sv |
23 | 24 | from PIL import Image as PILImage |
24 | 25 |
|
25 | | -from rfdetr.export._tflite.inference import _create_interpreter, _run_inference |
| 26 | +from rfdetr.export._tflite.inference import _create_interpreter, _decode_masks, _run_inference |
26 | 27 |
|
27 | 28 | # --------------------------------------------------------------------------- |
28 | 29 | # Shared helpers / factories |
@@ -439,3 +440,147 @@ def _get_tensor(index: int) -> np.ndarray: |
439 | 440 | dets, _ = _run_inference(interp, rgb_image, threshold=0.3) |
440 | 441 | assert isinstance(dets, sv.Detections) |
441 | 442 | assert len(dets) >= 1 |
| 443 | + |
| 444 | + |
| 445 | +# --------------------------------------------------------------------------- |
| 446 | +# TestMaskDecoding |
| 447 | +# --------------------------------------------------------------------------- |
| 448 | + |
| 449 | + |
class TestMaskDecoding:
    """Tests for ``_decode_masks()`` and mask decoding in ``_run_inference()``."""

    @pytest.fixture()
    def rgb_image(self, tmp_path: Path) -> Path:
        """Write a small RGB JPEG to a temp file and return its path."""
        p = tmp_path / "image.jpg"
        _save_rgb_image(p)
        return p

    @staticmethod
    def _make_seg_interp(
        boxes: np.ndarray,
        logits: np.ndarray,
        masks: np.ndarray,
        names: tuple[str, str, str] = ("Identity_0", "Identity_1", "Identity_2"),
    ) -> mock.MagicMock:
        """Build a mock interpreter that exposes three outputs: boxes, logits, masks.

        The leading underscore keeps pytest from collecting this helper as a test.

        Args:
            boxes: Tensor returned for output index 1.
            logits: Tensor returned for output index 2.
            masks: Tensor returned for output index 3.
            names: Output names reported by ``get_output_details()``, in the
                order (boxes, logits, masks).

        Returns:
            A ``MagicMock`` whose ``get_input_details``, ``get_output_details``
            and ``get_tensor`` are configured for a 3-output export.
        """
        tensors = {1: boxes, 2: logits, 3: masks}

        def _get_tensor(index: int) -> np.ndarray:
            return tensors[index]

        interp = mock.MagicMock()
        interp.get_input_details.return_value = [{"shape": _INPUT_SHAPE, "index": 0, "dtype": np.float32}]
        interp.get_output_details.return_value = [
            {"shape": [1, 10, 4], "name": names[0], "index": 1},
            {"shape": [1, 10, 82], "name": names[1], "index": 2},
            {"shape": [1, 10, 28, 28], "name": names[2], "index": 3},
        ]
        interp.get_tensor.side_effect = _get_tensor
        return interp

    def test_decode_masks_shape_and_dtype(self) -> None:
        """Output shape is (K, height, width) from out_size=(width, height); dtype is bool."""
        out = _decode_masks(np.zeros((3, 10, 10), dtype=np.float32), (40, 20))
        assert out.shape == (3, 20, 40)
        assert out.dtype == bool

    def test_decode_masks_thresholds_at_zero(self) -> None:
        """Positive logits decode to True, negative logits to False."""
        logits = np.stack(
            [
                np.full((8, 8), 5.0, dtype=np.float32),
                np.full((8, 8), -5.0, dtype=np.float32),
            ]
        )
        out = _decode_masks(logits, (16, 16))
        assert out[0].all()
        assert not out[1].any()

    def test_decode_masks_empty_input(self) -> None:
        """Zero masks in yields a (0, height, width) array, not an error."""
        out = _decode_masks(np.zeros((0, 10, 10), dtype=np.float32), (32, 32))
        assert out.shape == (0, 32, 32)

    def test_run_inference_decodes_masks_for_seg_model(self, rgb_image: Path) -> None:
        """A 3-output segmentation export populates Detections.mask at image size."""
        masks = np.full((1, 10, 28, 28), -10.0, dtype=np.float32)
        masks[0, 0] = 10.0  # query 0 (the kept detection) gets an all-positive mask
        interp = self._make_seg_interp(_make_boxes(), _make_logits(high_conf_idx=0), masks)

        dets, img = _run_inference(interp, rgb_image, threshold=0.3)
        assert dets.mask is not None
        assert dets.mask.shape == (len(dets), img.height, img.width)
        assert dets.mask.dtype == bool
        assert dets.mask[0].all()  # query 0's all-positive logits decode to a full mask

    def test_run_inference_no_mask_for_detection_model(self, rgb_image: Path) -> None:
        """A 2-output detection export leaves Detections.mask as None."""
        interp = _make_interp(logits=_make_logits(high_conf_idx=0))
        dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
        assert dets.mask is None

    def test_run_inference_name_based_mask_detection(self, rgb_image: Path) -> None:
        """Output named 'masks:0' exercises the name-based path and sets Detections.mask."""
        masks = np.full((1, 10, 28, 28), 10.0, dtype=np.float32)
        interp = self._make_seg_interp(
            _make_boxes(),
            _make_logits(high_conf_idx=0),
            masks,
            names=(
                "serving_default_dets:0",
                "serving_default_labels:0",
                "serving_default_masks:0",
            ),
        )

        dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
        assert dets.mask is not None

    def test_run_inference_seg_model_no_detections_returns_none_mask(self, rgb_image: Path) -> None:
        """Seg model with all scores below threshold returns mask=None (keep.any() is False)."""
        logits = _make_logits(high_conf_idx=None)  # all scores near zero, below threshold
        masks = np.full((1, 10, 28, 28), 10.0, dtype=np.float32)
        interp = self._make_seg_interp(_make_boxes(), logits, masks)

        dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
        assert len(dets) == 0
        assert dets.mask is None

    def test_decode_masks_raises_on_wrong_rank(self) -> None:
        """_decode_masks raises ValueError when input is not rank-3."""
        with pytest.raises(ValueError, match="rank-3"):
            _decode_masks(np.zeros((10, 28, 28, 1), dtype=np.float32), (56, 56))

    def test_decode_masks_exact_zero_logit_decodes_to_false(self) -> None:
        """Logit exactly 0.0 is not > 0.0 and decodes to False (strict threshold)."""
        zero_logits = np.zeros((1, 8, 8), dtype=np.float32)
        out = _decode_masks(zero_logits, (16, 16))
        assert not out.any()

    def test_decode_masks_non_square_logit_input(self) -> None:
        """Non-square logit map (K, Hm, Wm) with Hm != Wm resizes to the correct output shape."""
        logits = np.full((3, 7, 14), 5.0, dtype=np.float32)
        out = _decode_masks(logits, (56, 28))  # out_size=(width=56, height=28)
        assert out.shape == (3, 28, 56)
        assert out.all()  # all-positive logits → all True

    def test_decode_masks_parity_positive_negative_regions(self) -> None:
        """Positive/negative logit regions map correctly after bilinear upsample + threshold.

        Uses high-magnitude logits (±10) so no ambiguity near the boundary; verifies
        the core _decode_masks contract matches the >0 PostProcess.forward equivalent.
        """
        logits = np.full((1, 14, 14), -10.0, dtype=np.float32)
        logits[0, :7, :] = 10.0  # top half strongly positive, bottom half strongly negative
        out = _decode_masks(logits, (28, 28))
        # Only rows clear of the sign change at the vertical midpoint are checked.
        assert out[0, 1:6, :].all()  # top rows → all True
        assert not out[0, 15:27, :].any()  # bottom rows → all False
0 commit comments