Skip to content

Commit 4277360

Browse files
authored
Merge branch 'develop' into fix-resize-crop
2 parents d443c41 + 8ae6767 commit 4277360

3 files changed

Lines changed: 221 additions & 22 deletions

File tree

src/rfdetr/export/_tflite/converter.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,10 @@
4747
this normalization at inference time.
4848
4949
Note:
50-
**Segmentation model export is not validated.** The same ``convert_kwargs``
51-
are applied to segmentation models as to detection models, but the
52-
segmentation output path introduces additional ops (``ScatterND``,
53-
``Resize``, extra ``GridSample`` calls in the mask resampling path) that
54-
have not been exercised end-to-end through TFLite. Treat segmentation
55-
TFLite export as experimental and verify outputs against the ONNX baseline
56-
before deployment.
50+
Segmentation models additionally emit a ``masks`` output. FP32, FP16,
51+
and dynamic-range INT8 all match the PyTorch baseline closely (INT8 mask
52+
fidelity is marginally lower). Verified on the non-plus segmentation
53+
variants: Nano, Small, Medium, Large, and Preview.
5754
"""
5855

5956
from __future__ import annotations
@@ -479,9 +476,10 @@ def export_tflite(
479476
substituted with TFLite-native pseudo-operators to avoid a missing
480477
TensorFlow Flex delegate at inference time.
481478
482-
Segmentation export (``pred_masks`` output) is **not validated** in
483-
the current implementation; additional operators may need to be
484-
added to ``replace_to_pseudo_operators`` for segmentation models.
479+
Segmentation models additionally emit a ``masks`` output, decoded by
480+
:func:`rfdetr.export._tflite.inference._run_inference`. Verified on
481+
the non-plus segmentation variants (Nano, Small, Medium, Large,
482+
Preview).
485483
"""
486484
onnx_path = Path(onnx_path)
487485
output_dir = Path(output_dir)

src/rfdetr/export/_tflite/inference.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
"""TFLite inference helpers for RF-DETR exported models.
88
99
These functions handle interpreter creation, image preprocessing, and
10-
detection decoding without requiring PyTorch or the RF-DETR training stack —
11-
only ``tflite-runtime`` (or ``tensorflow``), ``numpy``, ``supervision``, and
12-
``Pillow`` are needed at inference time.
10+
decoding of detection and segmentation-mask outputs without requiring PyTorch
11+
or the RF-DETR training stack: only ``tflite-runtime`` (or ``tensorflow``),
12+
``numpy``, ``supervision``, and ``Pillow`` are needed at inference time.
1313
"""
1414

1515
from __future__ import annotations
@@ -19,12 +19,16 @@
1919

2020
import numpy as np
2121
import supervision as sv
22+
from numpy.typing import NDArray
2223
from PIL import Image as PILImage
2324

2425
from rfdetr.utilities.logger import get_logger
2526

2627
logger = get_logger()
2728

29+
# PILImage.Resampling was introduced in Pillow 9.1; fall back to the legacy constant.
30+
_PIL_BILINEAR = getattr(PILImage, "Resampling", PILImage).BILINEAR
31+
2832

2933
def _create_interpreter(model_path: str | Path) -> Any:
3034
"""Load a TFLite model, allocate tensors, and log I/O shapes.
@@ -64,6 +68,38 @@ def _create_interpreter(model_path: str | Path) -> Any:
6468
return interp
6569

6670

71+
def _decode_masks(mask_logits: NDArray[Any], out_size: tuple[int, int]) -> NDArray[np.bool_]:
72+
"""Upsample raw mask logits to image size and threshold at zero.
73+
74+
Approximates ``PostProcess.forward``: bilinear resize followed by ``> 0``.
75+
Uses Pillow's bilinear resampling rather than ``F.interpolate`` (no PyTorch
76+
dependency at inference time); border pixels may differ slightly due to
77+
distinct half-pixel conventions.
78+
79+
Args:
80+
mask_logits: Raw mask logits of shape ``(K, Hm, Wm)``.
81+
out_size: Target ``(width, height)`` in pixels.
82+
83+
Returns:
84+
Boolean mask array of shape ``(K, height, width)``.
85+
86+
Raises:
87+
ValueError: If *mask_logits* is not rank-3.
88+
"""
89+
if mask_logits.ndim != 3:
90+
raise ValueError(
91+
f"_decode_masks expects rank-3 (K, Hm, Wm); got shape {mask_logits.shape}. "
92+
"This usually means the rank-4 mask-output heuristic in _run_inference matched the wrong tensor."
93+
)
94+
width, height = out_size
95+
out = np.empty((mask_logits.shape[0], height, width), dtype=np.bool_)
96+
for i, logit_map in enumerate(mask_logits):
97+
mask_img = PILImage.fromarray(logit_map.astype(np.float32), mode="F")
98+
resized = mask_img.resize((width, height), _PIL_BILINEAR)
99+
out[i] = np.asarray(resized) > 0.0
100+
return out
101+
102+
67103
def _run_inference(
68104
interp: Any,
69105
image_path: str | Path,
@@ -75,6 +111,8 @@ def _run_inference(
75111
normalises the image with ImageNet statistics, invokes the model, then
76112
decodes the ``dets`` / ``labels`` output tensors into a
77113
:class:`supervision.Detections` object with pixel-space ``xyxy`` boxes.
114+
For segmentation exports the ``masks`` output is also decoded into
115+
``Detections.mask``.
78116
79117
Args:
80118
interp: Allocated TFLite interpreter returned by ``_create_interpreter``.
@@ -83,8 +121,8 @@ def _run_inference(
83121
84122
Returns:
85123
A tuple of ``(detections, pil_img)`` where ``detections`` contains
86-
pixel-space ``xyxy`` boxes and ``pil_img`` is the original PIL image
87-
at its original resolution.
124+
pixel-space ``xyxy`` boxes (and ``mask`` for segmentation models) and
125+
``pil_img`` is the original PIL image at its original resolution.
88126
"""
89127
inp_det = interp.get_input_details()
90128
out_det = interp.get_output_details()
@@ -119,11 +157,11 @@ def _run_inference(
119157
boxes_idx = next((i for i, od in enumerate(out_det) if "dets" in str(od.get("name", ""))), None)
120158
logits_idx = next((i for i, od in enumerate(out_det) if "labels" in str(od.get("name", ""))), None)
121159
if boxes_idx is None or logits_idx is None:
122-
# onnx2tf sometimes renames outputs to generic "Identity", "Identity_N" instead
123-
# of preserving the original ONNX node names. Fall back to shape-based
124-
# matching for the detection outputs only: boxes (*, 4) and logits
125-
# (*, num_classes+1). Segmentation exports may include additional outputs
126-
# such as masks; unnamed extra outputs are not resolved by this fallback.
160+
# onnx2tf sometimes renames outputs to generic "Identity", "Identity_N"
161+
# instead of preserving the original ONNX node names. Fall back to
162+
# shape-based matching: boxes are the rank-3 tensor with last dim 4,
163+
# logits the rank-3 tensor with last dim != 4. A rank-4 mask output,
164+
# if present, is matched separately below.
127165
logger.debug(
128166
"Name-based output matching failed (available: %s). Falling back to shape-based matching.",
129167
available_output_names,
@@ -177,4 +215,22 @@ def _run_inference(
177215
xyxy = np.stack([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], axis=1)
178216
xyxy *= np.array([ow, oh, ow, oh], dtype=np.float32)
179217

180-
return sv.Detections(xyxy=xyxy, confidence=scores[keep], class_id=cls[keep].astype(int)), pil_img
218+
# Segmentation exports add a rank-4 mask output; decode it when present.
219+
mask_idx = next((i for i, od in enumerate(out_det) if "masks" in str(od.get("name", ""))), None)
220+
if mask_idx is None:
221+
rank4_candidates = [i for i, od in enumerate(out_det) if len(od["shape"]) == 4]
222+
if len(rank4_candidates) == 1:
223+
mask_idx = rank4_candidates[0]
224+
elif len(rank4_candidates) >= 2:
225+
logger.warning(
226+
"Ambiguous rank-4 outputs (%d candidates); skipping mask decode. "
227+
"Name your mask output to contain 'masks' to disambiguate.",
228+
len(rank4_candidates),
229+
)
230+
masks = None
231+
if mask_idx is not None and keep.any():
232+
raw_masks = interp.get_tensor(out_det[mask_idx]["index"])[0] # (Q, Hm, Wm)
233+
masks = _decode_masks(raw_masks[keep], (ow, oh))
234+
235+
detections = sv.Detections(xyxy=xyxy, confidence=scores[keep], class_id=cls[keep].astype(int), mask=masks)
236+
return detections, pil_img

tests/export/test_tflite_inference.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Covers:
1010
* ``_create_interpreter()`` — interpreter loading with tflite_runtime / tensorflow fallback
1111
* ``_run_inference()`` — image preprocessing, invocation, and detection decoding
12+
* ``_decode_masks()`` — segmentation mask upsampling and thresholding
1213
"""
1314

1415
from __future__ import annotations
@@ -22,7 +23,7 @@
2223
import supervision as sv
2324
from PIL import Image as PILImage
2425

25-
from rfdetr.export._tflite.inference import _create_interpreter, _run_inference
26+
from rfdetr.export._tflite.inference import _create_interpreter, _decode_masks, _run_inference
2627

2728
# ---------------------------------------------------------------------------
2829
# Shared helpers / factories
@@ -439,3 +440,147 @@ def _get_tensor(index: int) -> np.ndarray:
439440
dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
440441
assert isinstance(dets, sv.Detections)
441442
assert len(dets) >= 1
443+
444+
445+
# ---------------------------------------------------------------------------
446+
# TestMaskDecoding
447+
# ---------------------------------------------------------------------------
448+
449+
450+
class TestMaskDecoding:
451+
"""Tests for ``_decode_masks()`` and mask decoding in ``_run_inference()``."""
452+
453+
@pytest.fixture()
454+
def rgb_image(self, tmp_path: Path) -> Path:
455+
"""Write a small RGB JPEG to a temp file and return its path."""
456+
p = tmp_path / "image.jpg"
457+
_save_rgb_image(p)
458+
return p
459+
460+
def test_decode_masks_shape_and_dtype(self) -> None:
461+
"""Output shape is (K, height, width) from out_size=(width, height); dtype is bool."""
462+
out = _decode_masks(np.zeros((3, 10, 10), dtype=np.float32), (40, 20))
463+
assert out.shape == (3, 20, 40)
464+
assert out.dtype == bool
465+
466+
def test_decode_masks_thresholds_at_zero(self) -> None:
467+
"""Positive logits decode to True, negative logits to False."""
468+
logits = np.stack(
469+
[
470+
np.full((8, 8), 5.0, dtype=np.float32),
471+
np.full((8, 8), -5.0, dtype=np.float32),
472+
]
473+
)
474+
out = _decode_masks(logits, (16, 16))
475+
assert out[0].all()
476+
assert not out[1].any()
477+
478+
def test_decode_masks_empty_input(self) -> None:
479+
"""Zero masks in yields a (0, height, width) array, not an error."""
480+
out = _decode_masks(np.zeros((0, 10, 10), dtype=np.float32), (32, 32))
481+
assert out.shape == (0, 32, 32)
482+
483+
def test_run_inference_decodes_masks_for_seg_model(self, rgb_image: Path) -> None:
484+
"""A 3-output segmentation export populates Detections.mask at image size."""
485+
boxes = _make_boxes()
486+
logits = _make_logits(high_conf_idx=0)
487+
masks = np.full((1, 10, 28, 28), -10.0, dtype=np.float32)
488+
masks[0, 0] = 10.0 # query 0 (the kept detection) gets an all-positive mask
489+
490+
def _get_tensor(index: int) -> np.ndarray:
491+
return {1: boxes, 2: logits, 3: masks}[index]
492+
493+
interp = mock.MagicMock()
494+
interp.get_input_details.return_value = [{"shape": _INPUT_SHAPE, "index": 0, "dtype": np.float32}]
495+
interp.get_output_details.return_value = [
496+
{"shape": [1, 10, 4], "name": "Identity_0", "index": 1},
497+
{"shape": [1, 10, 82], "name": "Identity_1", "index": 2},
498+
{"shape": [1, 10, 28, 28], "name": "Identity_2", "index": 3},
499+
]
500+
interp.get_tensor.side_effect = _get_tensor
501+
502+
dets, img = _run_inference(interp, rgb_image, threshold=0.3)
503+
assert dets.mask is not None
504+
assert dets.mask.shape == (len(dets), img.height, img.width)
505+
assert dets.mask.dtype == bool
506+
assert dets.mask[0].all() # query 0's all-positive logits decode to a full mask
507+
508+
def test_run_inference_no_mask_for_detection_model(self, rgb_image: Path) -> None:
509+
"""A 2-output detection export leaves Detections.mask as None."""
510+
interp = _make_interp(logits=_make_logits(high_conf_idx=0))
511+
dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
512+
assert dets.mask is None
513+
514+
def test_run_inference_name_based_mask_detection(self, rgb_image: Path) -> None:
515+
"""Output named 'masks:0' exercises the name-based path and sets Detections.mask."""
516+
boxes = _make_boxes()
517+
logits = _make_logits(high_conf_idx=0)
518+
masks = np.full((1, 10, 28, 28), 10.0, dtype=np.float32)
519+
520+
def _get_tensor(index: int) -> np.ndarray:
521+
return {1: boxes, 2: logits, 3: masks}[index]
522+
523+
interp = mock.MagicMock()
524+
interp.get_input_details.return_value = [{"shape": _INPUT_SHAPE, "index": 0, "dtype": np.float32}]
525+
interp.get_output_details.return_value = [
526+
{"shape": [1, 10, 4], "name": "serving_default_dets:0", "index": 1},
527+
{"shape": [1, 10, 82], "name": "serving_default_labels:0", "index": 2},
528+
{"shape": [1, 10, 28, 28], "name": "serving_default_masks:0", "index": 3},
529+
]
530+
interp.get_tensor.side_effect = _get_tensor
531+
532+
dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
533+
assert dets.mask is not None
534+
535+
def test_run_inference_seg_model_no_detections_returns_none_mask(self, rgb_image: Path) -> None:
536+
"""Seg model with all scores below threshold returns mask=None (keep.any() is False)."""
537+
boxes = _make_boxes()
538+
logits = _make_logits(high_conf_idx=None) # all scores near zero, below threshold
539+
masks = np.full((1, 10, 28, 28), 10.0, dtype=np.float32)
540+
541+
def _get_tensor(index: int) -> np.ndarray:
542+
return {1: boxes, 2: logits, 3: masks}[index]
543+
544+
interp = mock.MagicMock()
545+
interp.get_input_details.return_value = [{"shape": _INPUT_SHAPE, "index": 0, "dtype": np.float32}]
546+
interp.get_output_details.return_value = [
547+
{"shape": [1, 10, 4], "name": "Identity_0", "index": 1},
548+
{"shape": [1, 10, 82], "name": "Identity_1", "index": 2},
549+
{"shape": [1, 10, 28, 28], "name": "Identity_2", "index": 3},
550+
]
551+
interp.get_tensor.side_effect = _get_tensor
552+
553+
dets, _ = _run_inference(interp, rgb_image, threshold=0.3)
554+
assert len(dets) == 0
555+
assert dets.mask is None
556+
557+
def test_decode_masks_raises_on_wrong_rank(self) -> None:
558+
"""_decode_masks raises ValueError when input is not rank-3."""
559+
with pytest.raises(ValueError, match="rank-3"):
560+
_decode_masks(np.zeros((10, 28, 28, 1), dtype=np.float32), (56, 56))
561+
562+
def test_decode_masks_exact_zero_logit_decodes_to_false(self) -> None:
563+
"""Logit exactly 0.0 is not > 0.0 and decodes to False (strict threshold)."""
564+
zero_logits = np.zeros((1, 8, 8), dtype=np.float32)
565+
out = _decode_masks(zero_logits, (16, 16))
566+
assert not out.any()
567+
568+
def test_decode_masks_non_square_logit_input(self) -> None:
569+
"""Non-square logit map (K, Hm, Wm) with Hm != Wm resizes to the correct output shape."""
570+
logits = np.full((3, 7, 14), 5.0, dtype=np.float32)
571+
out = _decode_masks(logits, (56, 28)) # out_size=(width=56, height=28)
572+
assert out.shape == (3, 28, 56)
573+
assert out.all() # all-positive logits → all True
574+
575+
def test_decode_masks_parity_positive_negative_regions(self) -> None:
576+
"""Positive/negative logit regions map correctly after bilinear upsample + threshold.
577+
578+
Uses high-magnitude logits (±10) so no ambiguity near the boundary; verifies
579+
the core _decode_masks contract matches the >0 PostProcess.forward equivalent.
580+
"""
581+
logits = np.full((1, 14, 14), -10.0, dtype=np.float32)
582+
logits[0, :7, :] = 10.0 # top half strongly positive, bottom half strongly negative
583+
out = _decode_masks(logits, (28, 28))
584+
# Interior rows well away from the half-way boundary
585+
assert out[0, 1:6, :].all() # top rows → all True
586+
assert not out[0, 15:27, :].any() # bottom rows → all False

0 commit comments

Comments
 (0)