Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/ci-integrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ jobs:
python-version: ${{ matrix.python-version }}
activate-environment: true

- name: 🚀 Install Packages (plus extras)
- name: 🚀 Install Packages (plus + visual extras)
timeout-minutes: 5
# Install PyTorch CPU-only first (UV_TORCH_BACKEND=cpu works with 'uv pip')
run: uv pip install -e .[plus]
# [visual] pulls in supervision, required by the predict() call in
# tests/try_instantiate_all_models.py (supervision is optional since #1074).
run: uv pip install -e .[plus,visual]

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, quoted it. The install line now uses uv pip install -e ".[plus,visual]", matching the style already used in ci-tests-cpu.yml (".[train,cli,visual,kornia]").


- name: 🔎 Validate model instantiation and downloads
run: python tests/try_instantiate_all_models.py
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ dependencies = [
"tqdm", # progress bars during weight download
"transformers>=5.1.0,<6.0.0", # DINOv2 backbone loading
"pydantic>=2.0,<3", # ModelConfig / TrainConfig validation
"supervision", # inference output (Detections, Masks)
"pyDeprecate>=0.6,<0.8", # deprecation warnings for legacy APIs; >=0.6 for deprecated_class
"scipy", # models/matcher.py linear_sum_assignment, on the core import path
]

[project.optional-dependencies]
Expand All @@ -55,17 +55,18 @@ train = [
"torchmetrics[detection]>=1.2",
"faster-coco-eval>=1.7.2",
"pycocotools",
"scipy",
"albumentations>=1.4.24,<3.0.0",
"roboflow",
"rf100vl",
"supervision", # dataset/grid visualization (datasets.save_grids, datasets.synthetic)
]
onnx = [
"onnx>=1.16.0,<2.0",
"onnxsim<0.6.0", # TODO: onnxsim 0.6.0+ hangs on install
"onnx_graphsurgeon",
"onnxruntime",
"polygraphy",
"supervision", # inference output (Detections) from export._onnx.inference
]
trt = [
"pycuda",
Expand All @@ -79,6 +80,7 @@ tflite = [
"onnx>=1.20.0,<2.0",
"tf-keras>=2.16.0",
"tensorflow>=2.16.0",
"supervision", # inference output (Detections) from export._tflite.inference
]
kornia = [
"kornia>=0.7,<1", # GPU-side augmentation via on_after_batch_transfer
Expand All @@ -94,6 +96,7 @@ visual = [
"matplotlib",
"pandas",
"seaborn",
"supervision", # annotators / Detections for visualize.data and dataset grids
]
cli = ["jsonargparse[signatures]>=4.27.7"]
plus = ["rfdetr_plus>=1.0.1, <2.0.0"]
Expand Down
12 changes: 10 additions & 2 deletions src/rfdetr/datasets/save_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from __future__ import annotations

from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
import supervision as sv
import torch
import torchvision.transforms as T # noqa: N812
from matplotlib.axes import Axes
from torch.utils.data import DataLoader

from rfdetr.util.box_ops import box_cxcywh_to_xyxy
from rfdetr.util.logger import get_logger
from rfdetr.utilities.optional_imports import import_supervision

if TYPE_CHECKING:
import supervision as sv

logger = get_logger()

Expand Down Expand Up @@ -46,6 +51,7 @@ def save_grid(self) -> None:
Each grid is a 3x3 JPEG containing up to 9 images from a single batch, with bounding boxes and class labels
drawn on top.
"""
sv = import_supervision()
inv_normalize = T.Normalize(
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
Expand Down Expand Up @@ -103,6 +109,8 @@ def _annotate_and_plot(
"""
from PIL import Image as PILImage

sv = import_supervision()

resized_size = single_target["size"]
if isinstance(resized_size, torch.Tensor):
resized_size = resized_size.detach().cpu()
Expand Down
22 changes: 15 additions & 7 deletions src/rfdetr/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,26 @@
#
"""Synthetic dataset generation with COCO formatting."""

from __future__ import annotations

import json
import logging
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

import cv2
import numpy as np
import supervision as sv
from tqdm.auto import tqdm
from typing_extensions import Literal

from rfdetr.utilities.optional_imports import import_supervision

if TYPE_CHECKING:
import supervision as sv

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -112,8 +118,8 @@ def _normalize_split_ratios(split_ratios: SplitRatiosType) -> Dict[str, float]:

# Available shapes for synthetic dataset generation
SYNTHETIC_SHAPES = ["square", "triangle", "circle"]
# Available colors for synthetic dataset generation (RGB format)
SYNTHETIC_COLORS = {"red": sv.Color.RED, "green": sv.Color.GREEN, "blue": sv.Color.BLUE}
# Available colors for synthetic dataset generation (resolved to sv.Color at use time)
SYNTHETIC_COLORS = ["red", "green", "blue"]
Comment on lines +121 to +122


def draw_synthetic_shape(
Expand All @@ -134,6 +140,7 @@ def draw_synthetic_shape(
Tuple of ``(image_with_shape, polygon)`` where ``polygon`` is a flat list ``[x1, y1, x2, y2, …]`` suitable for
the COCO ``segmentation`` field. Returns an empty polygon list for unknown shape names.
"""
sv = import_supervision()
cx, cy = center
half_size = size // 2

Expand Down Expand Up @@ -218,8 +225,9 @@ def generate_synthetic_sample(
``detections`` is an :class:`sv.Detections` instance whose ``data["polygons"]`` field contains one flat ``[x1,
y1, x2, y2, …]`` polygon list per detection, matching the geometry returned by :func:`draw_synthetic_shape`.
"""
sv = import_supervision()
img = np.ones((img_size, img_size, 3), dtype=np.uint8) * 128
color_names = list(SYNTHETIC_COLORS.keys())
color_names = list(SYNTHETIC_COLORS)
num_objects = random.randint(min_objects, max_objects)

xyxys = []
Expand All @@ -231,7 +239,7 @@ def generate_synthetic_sample(
for _ in range(num_objects):
shape = random.choice(SYNTHETIC_SHAPES)
color_name = random.choice(color_names)
color = SYNTHETIC_COLORS[color_name]
color = getattr(sv.Color, color_name.upper())

if class_mode == "shape":
category_id = SYNTHETIC_SHAPES.index(shape)
Expand Down Expand Up @@ -457,7 +465,7 @@ def generate_coco_dataset(
if class_mode == "shape":
classes = SYNTHETIC_SHAPES
else:
classes = list(SYNTHETIC_COLORS.keys())
classes = list(SYNTHETIC_COLORS)

# Shuffle indices for splits
all_indices = list(range(num_images))
Expand Down
3 changes: 2 additions & 1 deletion src/rfdetr/datasets/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
make_coco_transforms,
make_coco_transforms_square_div_64,
)
from rfdetr.utilities.optional_imports import import_supervision

REQUIRED_YOLO_YAML_FILES = ["data.yaml", "data.yml"]
REQUIRED_SPLIT_DIRS = ["train", "valid"]
Expand Down Expand Up @@ -161,7 +162,7 @@ class _LazyYoloSample:

def to_detections(self) -> "sv.Detections":
"""Materialize the current sample as a supervision ``Detections`` object."""
import supervision as sv
sv = import_supervision()

if len(self.class_id) == 0:
return sv.Detections.empty()
Expand Down
4 changes: 3 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from rfdetr.utilities.decorators import deprecated
from rfdetr.utilities.distributed import is_main_process
from rfdetr.utilities.logger import get_logger
from rfdetr.utilities.optional_imports import import_supervision

try:
torch.set_float32_matmul_precision("high")
Expand Down Expand Up @@ -1264,8 +1265,9 @@ class IDs. The ``data`` dict of each :class:`~supervision.Detections` object con
if either dimension does not support the ``__index__`` protocol (e.g. ``float``) or is a ``bool``, if
either dimension is zero or negative, if either dimension is not divisible by ``patch_size *
num_windows``, or if ``patch_size`` is not a positive integer.
ImportError: If the optional ``supervision`` package is not installed.
"""
import supervision as sv
sv = import_supervision()

_ensure_model_on_device(self.model)

Expand Down
8 changes: 6 additions & 2 deletions src/rfdetr/export/_onnx/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
from __future__ import annotations

from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import supervision as sv
from PIL import Image as PILImage

from rfdetr.utilities.logger import get_logger
from rfdetr.utilities.optional_imports import import_supervision

if TYPE_CHECKING:
import supervision as sv

logger = get_logger()

Expand Down Expand Up @@ -206,4 +209,5 @@ def _run_inference(
xyxy = np.stack([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], axis=1)
xyxy *= np.array([ow, oh, ow, oh], dtype=np.float32)

sv = import_supervision()
return sv.Detections(xyxy=xyxy, confidence=scores[keep], class_id=cls[keep].astype(int)), pil_img
8 changes: 6 additions & 2 deletions src/rfdetr/export/_tflite/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
import contextlib
import importlib
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import supervision as sv
from numpy.typing import NDArray
from PIL import Image as PILImage

from rfdetr.utilities.logger import get_logger
from rfdetr.utilities.optional_imports import import_supervision

if TYPE_CHECKING:
import supervision as sv

logger = get_logger()

Expand Down Expand Up @@ -249,5 +252,6 @@ def _run_inference(
raw_masks = interp.get_tensor(out_det[mask_idx]["index"])[0] # (Q, Hm, Wm)
masks = _decode_masks(raw_masks[keep], (ow, oh))

sv = import_supervision()
detections = sv.Detections(xyxy=xyxy, confidence=scores[keep], class_id=cls[keep].astype(int), mask=masks)
return detections, pil_img
36 changes: 36 additions & 0 deletions src/rfdetr/utilities/optional_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""Helpers for importing optional third-party dependencies with friendly errors."""

from types import ModuleType
from typing import cast


def import_supervision() -> ModuleType:
"""Import the optional ``supervision`` package, raising a friendly hint if it is missing.

``supervision`` is an optional dependency: it is required for the ``Detections`` return type of
inference helpers and for annotation/visualization utilities, but is not installed by the core
``rfdetr`` package. This helper defers the import to call time so the rest of the package remains
usable without it, and turns a bare ``ModuleNotFoundError`` into an actionable installation hint.

Returns:
The imported ``supervision`` module.

Raises:
ImportError: If ``supervision`` is not installed.
"""
try:
import supervision as sv
except ModuleNotFoundError as exc:
if exc.name != "supervision":
raise
raise ImportError(
"This feature requires the 'supervision' package. Install it with "
"`pip install supervision` (also bundled in the rfdetr[onnx], rfdetr[tflite], "
"rfdetr[train], and rfdetr[visual] extras)."
) from exc
Comment thread
Borda marked this conversation as resolved.
return cast(ModuleType, sv)
Comment on lines +12 to +36

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'd prefer to leave this un-memoized:

  1. Python already caches the module in sys.modules, so after the first call import supervision is just a dict lookup + name binding (microseconds). The helper is also called once per predict() / _run_inference(), not per-sample, so there's no hot-loop cost in practice.
  2. functools.lru_cache would cache the successful module object and break tests/utilities/test_optional_imports.py::test_raises_with_install_hint_when_missing, which relies on monkeypatch.setitem(sys.modules, "supervision", None) to force the missing-package path after a prior successful import. Working around that (a cache_clear() in the test) adds complexity for a non-issue.

Happy to revisit if profiling ever shows this on a hot path.

3 changes: 2 additions & 1 deletion src/rfdetr/visualize/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PIL import Image

from rfdetr.utilities.logger import get_logger
from rfdetr.utilities.optional_imports import import_supervision

logger = get_logger()

Expand All @@ -31,7 +32,7 @@ def save_gt_predictions_visualization(
Boxes are labeled with class ID and confidence (for predictions). For predictions with known IoU, the IoU value is
also shown.
"""
import supervision as sv
sv = import_supervision()

save_dir.mkdir(exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion tests/datasets/test_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
import pytest
import supervision as sv

from rfdetr.datasets.synthetic import (
DEFAULT_SPLIT_RATIOS,
Expand All @@ -21,6 +20,8 @@
generate_synthetic_sample,
)

sv = pytest.importorskip("supervision")


class TestCalculateBoundaryOverlap:
@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion tests/export/test_tflite_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

import numpy as np
import pytest
import supervision as sv
from PIL import Image as PILImage

from rfdetr.export._tflite.inference import _create_interpreter, _decode_masks, _run_inference

sv = pytest.importorskip("supervision")

# ---------------------------------------------------------------------------
# Shared helpers / factories
# ---------------------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import numpy as np
import PIL.Image
import pytest
import supervision as sv
import torch

from rfdetr import RFDETRNano, RFDETRSegNano
from rfdetr.detr import RFDETR

sv = pytest.importorskip("supervision")

_HTTP_IMAGE_URL = "http://images.cocodataset.org/val2017/000000397133.jpg"
_HTTP_HOST = "images.cocodataset.org"
_HTTP_PORT = 80
Expand Down
29 changes: 29 additions & 0 deletions tests/utilities/test_optional_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
"""Tests for optional-dependency import helpers."""

import sys

import pytest

from rfdetr.utilities.optional_imports import import_supervision


class TestImportSupervision:
"""Tests for ``import_supervision()``."""

def test_returns_module_when_installed(self) -> None:
"""When ``supervision`` is installed, the helper returns the module."""
sv = pytest.importorskip("supervision")
assert import_supervision() is sv

def test_raises_with_install_hint_when_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""When ``supervision`` cannot be imported, a friendly ImportError with an install hint is raised."""
# Setting an entry to None makes the import machinery raise ImportError for that name.
monkeypatch.setitem(sys.modules, "supervision", None)

with pytest.raises(ImportError, match="pip install supervision"):
import_supervision()
Loading