Skip to content

Commit 8e3d27a

Browse files
Bordacodex
andcommitted
Use augmentation disable intent for crop branch
Co-authored-by: Codex <codex@openai.com>
1 parent d53a2da commit 8e3d27a

3 files changed

Lines changed: 40 additions & 34 deletions

File tree

src/rfdetr/datasets/coco.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _build_train_resize_config(
294294
*,
295295
square: bool,
296296
max_size: Optional[int] = None,
297-
include_crop_branch: bool = True,
297+
disable_augmentations: bool = False,
298298
) -> List[Dict[str, Any]]:
299299
"""Build the training resize pipeline as an Albumentations config list.
300300
@@ -317,15 +317,14 @@ def _build_train_resize_config(
317317
optional long-side cap.
318318
max_size: Maximum long-side size for non-square resizes. Defaults to
319319
``1333`` when *square* is ``False``.
320-
include_crop_branch: If ``False``, omit the resize-and-crop branch so the
321-
pipeline always uses Option A (direct resize). Useful when objects
322-
of interest can be partially or fully cropped out of frame.
323-
Defaults to ``True`` (preserves the original two-branch behavior).
320+
disable_augmentations: If ``True``, omit the resize-and-crop branch so
321+
explicitly disabled augmentations do not still randomly crop images.
322+
Defaults to ``False`` (preserves the original two-branch behavior).
324323
325324
Returns:
326-
A single-element list. When ``include_crop_branch`` is ``True`` (default)
327-
the entry wraps a ``OneOf`` over both branches; when ``False`` the
328-
entry is Option A directly.
325+
A single-element list. By default the entry wraps a ``OneOf`` over both
326+
branches; when augmentations are disabled, the entry is Option A
327+
directly.
329328
"""
330329
if square:
331330
option_a: Dict[str, Any] = {
@@ -371,14 +370,14 @@ def _build_train_resize_config(
371370
}
372371
}
373372

374-
if not include_crop_branch:
373+
if disable_augmentations:
375374
return [option_a]
376375

377376
return [{"OneOf": {"transforms": [option_a, option_b]}}]
378377

379378

380-
def _crop_branch_enabled(aug_config: Optional[Dict[str, Any]]) -> bool:
381-
"""Decide whether the training resize pipeline keeps its resize-and-crop branch.
379+
def _augmentations_disabled(aug_config: Optional[Dict[str, Any]]) -> bool:
380+
"""Decide whether the user explicitly disabled augmentations.
382381
383382
``aug_config={}`` is an explicit request to disable augmentations; it also
384383
drops the resize-and-crop branch. ``aug_config=None`` (the default) and any
@@ -390,8 +389,8 @@ def _crop_branch_enabled(aug_config: Optional[Dict[str, Any]]) -> bool:
390389
"augmentations; images will not be randomly cropped. Pass aug_config=None to keep "
391390
"the default resize pipeline."
392391
)
393-
return False
394-
return True
392+
return True
393+
return False
395394

396395

397396
def make_coco_transforms(
@@ -468,9 +467,13 @@ def make_coco_transforms(
468467

469468
if image_set == "train":
470469
resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG
471-
include_crop_branch = _crop_branch_enabled(aug_config)
472470
resize_wrappers = AlbumentationsWrapper.from_config(
473-
_build_train_resize_config(scales, square=False, max_size=1333, include_crop_branch=include_crop_branch)
471+
_build_train_resize_config(
472+
scales,
473+
square=False,
474+
max_size=1333,
475+
disable_augmentations=_augmentations_disabled(aug_config),
476+
)
474477
)
475478
pipeline = [*resize_wrappers]
476479
if not gpu_postprocess:
@@ -556,9 +559,12 @@ def make_coco_transforms_square_div_64(
556559

557560
if image_set == "train":
558561
resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG
559-
include_crop_branch = _crop_branch_enabled(aug_config)
560562
resize_wrappers = AlbumentationsWrapper.from_config(
561-
_build_train_resize_config(scales, square=True, include_crop_branch=include_crop_branch)
563+
_build_train_resize_config(
564+
scales,
565+
square=True,
566+
disable_augmentations=_augmentations_disabled(aug_config),
567+
)
562568
)
563569
pipeline = [*resize_wrappers]
564570
if not gpu_postprocess:

tests/datasets/test_coco.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -454,36 +454,36 @@ class TestAugConfigDisablesCrop:
454454
@pytest.mark.parametrize(
455455
"aug_config,expected",
456456
[
457-
pytest.param(None, True, id="none_keeps_crop"),
458-
pytest.param({}, False, id="empty_disables_crop"),
459-
pytest.param({"HorizontalFlip": {"p": 0.5}}, True, id="nonempty_keeps_crop"),
457+
pytest.param(None, False, id="none_keeps_augmentations"),
458+
pytest.param({}, True, id="empty_disables_augmentations"),
459+
pytest.param({"HorizontalFlip": {"p": 0.5}}, False, id="nonempty_keeps_augmentations"),
460460
],
461461
)
462-
def test_crop_branch_enabled(self, aug_config, expected):
463-
"""_crop_branch_enabled maps aug_config to the crop-branch decision."""
464-
from rfdetr.datasets.coco import _crop_branch_enabled
462+
def test_augmentations_disabled(self, aug_config, expected):
463+
"""_augmentations_disabled maps aug_config to the explicit disable decision."""
464+
from rfdetr.datasets.coco import _augmentations_disabled
465465

466-
assert _crop_branch_enabled(aug_config) is expected
466+
assert _augmentations_disabled(aug_config) is expected
467467

468468
def test_empty_aug_config_warns(self):
469469
"""Passing aug_config={} logs a warning about the dropped resize-and-crop branch."""
470470
from unittest.mock import patch
471471

472-
from rfdetr.datasets.coco import _crop_branch_enabled
472+
from rfdetr.datasets.coco import _augmentations_disabled
473473

474474
with patch("rfdetr.datasets.coco.logger") as mock_logger:
475-
_crop_branch_enabled({})
475+
_augmentations_disabled({})
476476
mock_logger.warning.assert_called_once()
477477

478478
@pytest.mark.parametrize(
479479
"aug_config,expected",
480480
[
481-
pytest.param(None, True, id="none_keeps_crop"),
482-
pytest.param({}, False, id="empty_disables_crop"),
481+
pytest.param(None, False, id="none_keeps_augmentations"),
482+
pytest.param({}, True, id="empty_disables_augmentations"),
483483
],
484484
)
485-
def test_make_coco_transforms_forwards_crop_decision(self, aug_config, expected):
486-
"""make_coco_transforms passes the aug_config-derived value to _build_train_resize_config."""
485+
def test_make_coco_transforms_forwards_disable_augmentations(self, aug_config, expected):
486+
"""make_coco_transforms passes the aug_config-derived disable decision to resize config."""
487487
from unittest.mock import patch
488488

489489
from rfdetr.datasets.coco import make_coco_transforms
@@ -497,4 +497,4 @@ def test_make_coco_transforms_forwards_crop_decision(self, aug_config, expected)
497497
):
498498
make_coco_transforms("train", 640, aug_config=aug_config)
499499

500-
assert mock_build.call_args.kwargs["include_crop_branch"] is expected
500+
assert mock_build.call_args.kwargs["disable_augmentations"] is expected

tests/datasets/test_coco_resize_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,15 +254,15 @@ def test_square_option_b_unchanged(self, scales, square):
254254

255255

256256
class TestBuildTrainResizeConfigCropBranch:
257-
"""include_crop_branch=False drops the resize-and-crop branch so only Option A runs."""
257+
"""disable_augmentations=True drops the resize-and-crop branch so only Option A runs."""
258258

259259
@pytest.mark.parametrize(
260260
"square",
261261
[pytest.param(True, id="square"), pytest.param(False, id="nonsquare")],
262262
)
263-
def test_include_crop_branch_false_drops_crop_branch(self, square):
263+
def test_disable_augmentations_drops_crop_branch(self, square):
264264
"""No RandomCrop/RandomSizedCrop appears in either square or non-square pipelines."""
265-
result = _build_train_resize_config([480, 640], square=square, include_crop_branch=False)
265+
result = _build_train_resize_config([480, 640], square=square, disable_augmentations=True)
266266
flat = str(result)
267267
assert "RandomSizedCrop" not in flat
268268
assert "RandomCrop" not in flat

0 commit comments

Comments
 (0)