Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,14 @@ uv sync --all-groups
See `pyproject.toml` for complete dependency specifications:

- **Core:** PyTorch, torchvision, transformers, supervision, pydantic, pyDeprecate
- **Optional:** `[train]` (training, including peft and pycocotools), `[lora]` (LoRA fine-tuning), `[plus]` (Plus models), `[onnx]` (ONNX export), `[loggers]` (tensorboard, wandb, mlflow, clearml)
- **Optional:** `[headless]` (Linux server inference with headless OpenCV), `[train]` (training, including peft and pycocotools), `[lora]` (LoRA fine-tuning), `[plus]` (Plus models), `[onnx]` (ONNX export), `[loggers]` (tensorboard, wandb, mlflow, clearml)
- **Development:** `tests`, `docs`, `build` groups

**Important version constraints:**

- PyTorch: >=2.2.0, \<3.0.0
- Transformers: >=5.1.0, \<6.0.0
- Headless inference: `rfdetr[headless]` constrains Supervision to the headless-OpenCV-compatible line and should be installed in a clean environment so only one `cv2` provider is present

## Testing

Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ To install RF-DETR, install the `rfdetr` package in a [**Python>=3.10**](https:/
pip install rfdetr
```

For Linux server or Docker inference images that should avoid GUI OpenCV system libraries, install the headless extra in a clean environment:

```bash
pip install "rfdetr[headless]"
```

<details>
<summary>Install from source</summary>

Expand Down
14 changes: 14 additions & 0 deletions docs/learn/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ RF-DETR supports several installation methods. Choose the option which best fits
uv add rfdetr
```

=== "Headless"

For Linux server or Docker inference environments where GUI OpenCV system libraries are not available, install the headless extra in a clean environment:

```bash
pip install "rfdetr[headless]"
```

With `uv`, use:

```bash
uv pip install "rfdetr[headless]"
```

=== "Source Archive"

To install the latest development version of RF-DETR from source without cloning the full repository, run the command below.
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ dependencies = [
"tqdm", # progress bars during weight download
"transformers>=5.1.0,<6.0.0", # DINOv2 backbone loading
"pydantic>=2.0,<3", # ModelConfig / TrainConfig validation
"supervision", # inference output (Detections, Masks)
"supervision>=0.20.0", # inference output (Detections, Masks)
"pyDeprecate>=0.6,<0.8", # deprecation warnings for legacy APIs; >=0.6 for deprecated_class
]

[project.optional-dependencies]
headless = [
"opencv-python-headless", # OpenCV without GUI system libraries for Linux server images
"supervision>=0.20.0,<0.21.0", # last Supervision line that depends on opencv-python-headless
Comment thread
Borda marked this conversation as resolved.
Outdated
]
lora = [
"peft", # LoRA backbone fine-tuning
]
Expand Down
17 changes: 16 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ def _resolve_patch_size(patch_size: int | None, model_config: object, caller: st
return patch_size


def _attach_detection_metadata(detections: Any, key: str, value: Any) -> None:
"""Attach metadata to Supervision Detections across supported versions.

Args:
detections: A Supervision Detections object.
key: Metadata key to set.
value: Metadata value to attach.
"""
metadata = getattr(detections, "metadata", None)
if metadata is None:
metadata = {}
setattr(detections, "metadata", metadata)
metadata[key] = value


def _ensure_model_on_device(model_ctx: Any) -> None:
"""Move model weights to the target device recorded in *model_ctx*.

Expand Down Expand Up @@ -1349,7 +1364,7 @@ def predict(
)

if include_source_image:
detections.metadata["source_image"] = source_images[i]
_attach_detection_metadata(detections, "source_image", source_images[i])
detections.data["source_shape"] = np.tile(np.array(orig_sizes[i], dtype=np.int64), (len(detections), 1))

# Attach class names so callers can map class_id → name without a
Expand Down
31 changes: 25 additions & 6 deletions tests/cli/test_optional_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ class TestOptionalDependencies:
"""Validate selected extras constraints in pyproject.toml."""

@staticmethod
def read_loggers_extra() -> list[str]:
"""Return the loggers optional-dependency list from pyproject.toml."""
def read_extra(name: str) -> list[str]:
"""Return an optional-dependency list from pyproject.toml."""
root = pathlib.Path(__file__).parent.parent.parent
pyproject = tomllib.loads((root / "pyproject.toml").read_text())
loggers = pyproject["project"]["optional-dependencies"].get("loggers")
assert loggers, "loggers extra not found in [project.optional-dependencies]"
return loggers
extra = pyproject["project"]["optional-dependencies"].get(name)
assert extra, f"{name} extra not found in [project.optional-dependencies]"
return extra

@staticmethod
def has_upper_bound_below_4(requirement: Requirement) -> bool:
Expand All @@ -44,14 +44,33 @@ def has_upper_bound_below_4(requirement: Requirement) -> bool:

def test_loggers_extra_pins_protobuf_below_4(self):
"""loggers extra must constrain protobuf for TensorBoard compatibility."""
requirements = [Requirement(dep) for dep in self.read_loggers_extra()]
requirements = [Requirement(dep) for dep in self.read_extra("loggers")]
protobuf_requirements = [req for req in requirements if req.name == "protobuf"]
assert protobuf_requirements, "loggers extra must include protobuf dependency"

assert any(self.has_upper_bound_below_4(req) for req in protobuf_requirements), (
"protobuf dependency must include an upper bound below 4.0.0"
)

def test_headless_extra_uses_headless_opencv(self) -> None:
    """The headless extra lists headless OpenCV and not the GUI distribution."""
    parsed = [Requirement(spec) for spec in self.read_extra("headless")]
    declared_names = set()
    for requirement in parsed:
        declared_names.add(requirement.name)

    assert "opencv-python-headless" in declared_names
    assert "opencv-python" not in declared_names

def test_headless_extra_constrains_supervision_before_opencv_switch(self) -> None:
    """The headless extra's supervision pin accepts 0.20.0 and rejects 0.21.0."""
    parsed = [Requirement(spec) for spec in self.read_extra("headless")]
    supervision_pins = [pin for pin in parsed if pin.name == "supervision"]
    assert supervision_pins, "headless extra must constrain supervision"

    # At least one supervision entry must admit the 0.20 line while
    # excluding the 0.21 line.
    accepted, rejected = Version("0.20.0"), Version("0.21.0")
    found_compatible_pin = False
    for pin in supervision_pins:
        if pin.specifier.contains(accepted) and not pin.specifier.contains(rejected):
            found_compatible_pin = True
            break
    assert found_compatible_pin

@pytest.mark.parametrize(
"dep_str,expected",
[
Expand Down
11 changes: 10 additions & 1 deletion tests/models/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from rfdetr import RFDETRNano, RFDETRSegNano
from rfdetr.detr import RFDETR
from rfdetr.detr import RFDETR, _attach_detection_metadata

_HTTP_IMAGE_URL = "http://images.cocodataset.org/val2017/000000397133.jpg"
_HTTP_HOST = "images.cocodataset.org"
Expand Down Expand Up @@ -100,6 +100,15 @@ def test_predict_accepts_image_url() -> None:
class TestPredictSourceData:
"""Verify ``predict()`` source metadata behavior."""

def test_attach_detection_metadata_handles_legacy_detections(self) -> None:
    """Attaching metadata succeeds on an object that has no ``metadata`` attribute."""
    image = np.zeros((48, 64, 3), dtype=np.uint8)
    # A bare namespace stands in for a detections object without ``metadata``.
    legacy = SimpleNamespace()

    _attach_detection_metadata(legacy, "source_image", image)

    stored = legacy.metadata["source_image"]
    assert stored is image

def test_source_image_included_by_default(self) -> None:
"""source_image remains included by default for API compatibility."""
img = PIL.Image.new("RGB", (64, 48), color=(128, 128, 128))
Expand Down