Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/learn/train/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ To disable all augmentations, pass an empty dict:
model.train(dataset_dir="path/to/dataset", aug_config={})
```

Passing `{}` also drops the training resize-and-crop branch, so images are resized

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's format it as a `!!! tip` admonition.

directly to the target scale without random cropping. To keep the default resize
pipeline, omit `aug_config` or pass `aug_config=None`.

---

## Memory Optimization
Expand Down
2 changes: 1 addition & 1 deletion docs/learn/train/augmentations.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ model.train(
)
```

To disable augmentations: `aug_config={}`. Omitting it uses the default (horizontal flip at 50%).
To disable augmentations: `aug_config={}`. This also drops the training resize-and-crop branch, so images are resized directly to the target scale without random cropping. Omitting it (or passing `aug_config=None`) uses the default (horizontal flip at 50%) and keeps the resize-and-crop branch.

## Built-in Presets

Expand Down
2 changes: 1 addition & 1 deletion src/rfdetr/datasets/aug_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

model.train(dataset_dir="...", aug_config=AUG_CONSERVATIVE)
model.train(dataset_dir="...", aug_config=AUG_AGGRESSIVE)

# Disable all augmentations
# Disable all augmentations (also drops the training resize-and-crop branch)
model.train(dataset_dir="...", aug_config={})

# Fully custom
Expand Down
43 changes: 40 additions & 3 deletions src/rfdetr/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def _build_train_resize_config(
*,
square: bool,
max_size: Optional[int] = None,
disable_augmentations: bool = False,
) -> List[Dict[str, Any]]:
Comment on lines 292 to 298
"""Build the training resize pipeline as an Albumentations config list.

Expand All @@ -316,9 +317,14 @@ def _build_train_resize_config(
optional long-side cap.
max_size: Maximum long-side size for non-square resizes. Defaults to
``1333`` when *square* is ``False``.
disable_augmentations: If ``True``, omit the resize-and-crop branch so
explicitly disabled augmentations do not still randomly crop images.
Defaults to ``False`` (preserves the original two-branch behavior).

Returns:
A single-element list containing a ``OneOf`` config entry.
A single-element list. By default the entry wraps a ``OneOf`` over both
branches; when augmentations are disabled, the entry is Option A
directly.
"""
if square:
option_a: Dict[str, Any] = {
Expand Down Expand Up @@ -364,9 +370,29 @@ def _build_train_resize_config(
}
}

if disable_augmentations:
return [option_a]

return [{"OneOf": {"transforms": [option_a, option_b]}}]


def _augmentations_disabled(aug_config: Optional[Dict[str, Any]]) -> bool:
"""Decide whether the user explicitly disabled augmentations.

``aug_config={}`` is an explicit request to disable augmentations; it also
drops the resize-and-crop branch. ``aug_config=None`` (the default) and any
non-empty config keep it.
"""
if aug_config == {}:
logger.warning(
"aug_config={} disables the training resize-and-crop branch in addition to all "
"augmentations; images will not be randomly cropped. Pass aug_config=None to keep "
"the default resize pipeline."
)
return True
return False


def make_coco_transforms(
image_set: str,
resolution: int,
Expand Down Expand Up @@ -442,7 +468,12 @@ def make_coco_transforms(
if image_set == "train":
resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG
resize_wrappers = AlbumentationsWrapper.from_config(
_build_train_resize_config(scales, square=False, max_size=1333)
_build_train_resize_config(
scales,
square=False,
max_size=1333,
disable_augmentations=_augmentations_disabled(aug_config),
)
)
Comment thread
omkar-334 marked this conversation as resolved.
Comment on lines 454 to 463
pipeline = [*resize_wrappers]
if not gpu_postprocess:
Expand Down Expand Up @@ -528,7 +559,13 @@ def make_coco_transforms_square_div_64(

if image_set == "train":
resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG
resize_wrappers = AlbumentationsWrapper.from_config(_build_train_resize_config(scales, square=True))
resize_wrappers = AlbumentationsWrapper.from_config(
_build_train_resize_config(
scales,
square=True,
disable_augmentations=_augmentations_disabled(aug_config),
)
)
pipeline = [*resize_wrappers]
if not gpu_postprocess:
aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config)
Expand Down
55 changes: 55 additions & 0 deletions tests/datasets/test_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,58 @@ def test_gpu_postprocess_flag(

call_kwargs = mock_transforms.call_args.kwargs if mock_transforms.call_args else mock_transforms.call_args[1]
assert call_kwargs["gpu_postprocess"] is expected_gpu_postprocess


class TestAugConfigDisablesCrop:
    """``aug_config={}`` disables the training resize-and-crop branch.

    ``aug_config=None`` (the default) and any non-empty config keep it.
    """

    @pytest.mark.parametrize(
        "aug_config,expected",
        [
            pytest.param(None, False, id="none_keeps_augmentations"),
            pytest.param({}, True, id="empty_disables_augmentations"),
            pytest.param({"HorizontalFlip": {"p": 0.5}}, False, id="nonempty_keeps_augmentations"),
        ],
    )
    def test_augmentations_disabled(self, aug_config, expected):
        """_augmentations_disabled maps aug_config to the explicit disable decision."""
        from rfdetr.datasets.coco import _augmentations_disabled

        # Identity check: the helper must return real booleans, not truthy values.
        assert _augmentations_disabled(aug_config) is expected

    def test_empty_aug_config_warns(self):
        """Passing aug_config={} logs a warning about the dropped resize-and-crop branch."""
        from unittest.mock import patch

        from rfdetr.datasets.coco import _augmentations_disabled

        # Replace the module-level logger so the warning call can be observed.
        with patch("rfdetr.datasets.coco.logger") as mock_logger:
            _augmentations_disabled({})
        mock_logger.warning.assert_called_once()

    @pytest.mark.parametrize(
        "aug_config,expected",
        [
            pytest.param(None, False, id="none_keeps_augmentations"),
            pytest.param({}, True, id="empty_disables_augmentations"),
        ],
    )
    def test_make_coco_transforms_forwards_disable_augmentations(self, aug_config, expected):
        """make_coco_transforms passes the aug_config-derived disable decision to resize config."""
        from unittest.mock import patch

        from rfdetr.datasets.coco import make_coco_transforms

        # Patch AlbumentationsWrapper.from_config so the test exercises only argument
        # forwarding; without it the mocked empty configs trigger unrelated
        # "Empty augmentation config provided" warnings.
        with (
            patch("rfdetr.datasets.coco._build_train_resize_config", return_value=[]) as mock_build,
            patch("rfdetr.datasets.coco.AlbumentationsWrapper.from_config", return_value=[]),
        ):
            make_coco_transforms("train", 640, aug_config=aug_config)

        assert mock_build.call_args.kwargs["disable_augmentations"] is expected
15 changes: 15 additions & 0 deletions tests/datasets/test_coco_resize_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,18 @@ def test_square_option_b_unchanged(self, scales, square):
for entry in inner_transforms:
assert "RandomSizedCrop" in entry
assert entry["RandomSizedCrop"]["min_max_height"] == [384, 600]


class TestBuildTrainResizeConfigCropBranch:
    """disable_augmentations=True drops the resize-and-crop branch so only Option A runs."""

    @pytest.mark.parametrize(
        "square",
        [pytest.param(True, id="square"), pytest.param(False, id="nonsquare")],
    )
    def test_disable_augmentations_drops_crop_branch(self, square):
        """No RandomCrop/RandomSizedCrop appears in either square or non-square pipelines."""
        result = _build_train_resize_config([480, 640], square=square, disable_augmentations=True)
        # The builder returns a single-element list in both modes.
        assert len(result) == 1
        # Searching the stringified config catches crop transforms at any nesting depth.
        flat = str(result)
        assert "RandomSizedCrop" not in flat
        assert "RandomCrop" not in flat
Loading