Coverage for tests/tests_advertrain/test_transforms.py: 100%

26 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1from torchvision.transforms import (ColorJitter, Compose, Pad, RandomAffine, 

2 RandomApply, RandomHorizontalFlip, 

3 RandomVerticalFlip, Resize, ToTensor) 

4 

5from robustML.advertrain.transforms import DataTransformations 

6 

7 

def test_initialization():
    """The constructor keeps the ``train_prob`` it was given as an attribute."""
    expected_prob = 0.7
    transformations = DataTransformations(train_prob=expected_prob)
    assert transformations.train_prob == expected_prob

11 

12 

def test_get_train_transforms():
    """Check the structure of the training pipeline.

    Expected order inside the ``Compose``: Pad, horizontal flip, vertical flip,
    a ``RandomApply`` wrapping (RandomAffine, RandomAffine, ColorJitter) with
    probability 0.5, then Resize and ToTensor.
    """
    # NOTE(review): 0.5 may equal DataTransformations' default train_prob (the
    # default is not visible here). A distinct value would better prove the
    # argument is forwarded to RandomApply.p.
    pipeline = DataTransformations(train_prob=0.5).get_train_transforms()
    assert isinstance(pipeline, Compose)

    steps = pipeline.transforms
    leading = (Pad, RandomHorizontalFlip, RandomVerticalFlip, RandomApply)
    for position, expected in enumerate(leading):
        assert isinstance(steps[position], expected)

    # Index 3 is the probabilistic augmentation group.
    augment = steps[3]
    for position, expected in enumerate((RandomAffine, RandomAffine, ColorJitter)):
        assert isinstance(augment.transforms[position], expected)
    assert augment.p == 0.5

    assert isinstance(steps[4], Resize)
    assert isinstance(steps[5], ToTensor)

28 

29 

def test_get_test_transforms():
    """With default arguments, the evaluation pipeline is Compose(Pad, Resize, ToTensor)."""
    pipeline = DataTransformations().get_test_transforms()
    assert isinstance(pipeline, Compose)

    for position, expected in enumerate((Pad, Resize, ToTensor)):
        assert isinstance(pipeline.transforms[position], expected)