Coverage for tests/tests_advertrain/test_transforms.py: 100%
26 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
1from torchvision.transforms import (ColorJitter, Compose, Pad, RandomAffine,
2 RandomApply, RandomHorizontalFlip,
3 RandomVerticalFlip, Resize, ToTensor)
5from robustML.advertrain.transforms import DataTransformations
def test_initialization():
    """Constructing DataTransformations keeps the given ``train_prob`` value."""
    probability = 0.7
    transformations = DataTransformations(train_prob=probability)
    assert transformations.train_prob == probability
def test_get_train_transforms():
    """The training pipeline is a Compose with the expected step at each position.

    Position 3 is a RandomApply whose inner transforms are RandomAffine,
    RandomAffine and ColorJitter. Its ``p`` is expected to equal the
    ``train_prob`` passed to the constructor.
    """
    pipeline = DataTransformations(train_prob=0.5).get_train_transforms()
    assert isinstance(pipeline, Compose)

    steps = pipeline.transforms
    # Indexed (not zipped) so a pipeline that is too short still fails loudly.
    leading = (Pad, RandomHorizontalFlip, RandomVerticalFlip, RandomApply)
    for index, expected in enumerate(leading):
        assert isinstance(steps[index], expected)

    augmentation = steps[3]
    for index, expected in enumerate((RandomAffine, RandomAffine, ColorJitter)):
        assert isinstance(augmentation.transforms[index], expected)
    assert augmentation.p == 0.5

    assert isinstance(steps[4], Resize)
    assert isinstance(steps[5], ToTensor)
def test_get_test_transforms():
    """The evaluation pipeline is a Compose with Pad, Resize, ToTensor at positions 0-2."""
    pipeline = DataTransformations().get_test_transforms()
    assert isinstance(pipeline, Compose)

    # Indexed (not zipped) so a missing step raises instead of being skipped.
    for index, expected in enumerate((Pad, Resize, ToTensor)):
        assert isinstance(pipeline.transforms[index], expected)