Coverage for tests/tests_advertrain/test_models.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import torch 

2import torch.nn as nn 

3 

4from robustML.advertrain.models import ConvNet, ResNet 

5 

6torch.manual_seed(0) 

7 

8 

def test_ConvNet_initialization():
    """Verify the layer attributes of a freshly constructed ConvNet.

    Builds a ConvNet on the CPU device and checks that each expected
    attribute exists and is an instance of the expected torch.nn layer class.
    """
    # (attribute name, expected layer class), checked in this order.
    expected_layers = (
        ("conv1", nn.Conv2d),
        ("conv2_1", nn.Conv2d),
        ("conv3_1", nn.Conv2d),
        ("conv4_1", nn.Conv2d),
        ("pooling", nn.MaxPool2d),
        ("activation", nn.ReLU),
        ("linear1", nn.Linear),
        ("linear2", nn.Linear),
        ("linear3", nn.Linear),
    )

    net = ConvNet(torch.device("cpu"))

    for attr_name, layer_type in expected_layers:
        assert isinstance(getattr(net, attr_name), layer_type)

22 

23 

def test_ConvNet_forward_pass():
    """Verify the output shape of a ConvNet forward pass on the CPU.

    Feeds one random tensor of shape (1, 3, 64, 128) through the model and
    checks that the result has shape (1, 2).
    """
    cpu = torch.device("cpu")
    expected_shape = torch.Size([1, 2])

    # Build the model before drawing the random input, as the original does,
    # so the seeded RNG is consumed in the same order.
    net = ConvNet(cpu)
    sample = torch.randn(1, 3, 64, 128, device=cpu)

    result = net(sample)

    assert result.shape == expected_shape

33 

34 

def test_ResNet_initialization():
    """Verify the layer attributes of a freshly constructed ResNet.

    Builds a ResNet on the CPU device and checks, in order, that:

    * ``conv1`` .. ``conv12`` are ``nn.Conv2d`` and each ``conv<i>_bn`` is
      ``nn.BatchNorm2d``;
    * every third stage (3, 6, 9, 12) also has a ``conv<i>_drop`` that is
      ``nn.Dropout2d``;
    * ``fc1`` / ``fc2`` are ``nn.Linear`` and ``fc1_bn`` is ``nn.BatchNorm1d``.
    """
    expected_layers = []
    for stage in range(1, 13):
        expected_layers.append((f"conv{stage}", nn.Conv2d))
        # Only stages 3, 6, 9 and 12 carry a dropout layer; it is checked
        # between the convolution and its batch norm.
        if stage % 3 == 0:
            expected_layers.append((f"conv{stage}_drop", nn.Dropout2d))
        expected_layers.append((f"conv{stage}_bn", nn.BatchNorm2d))

    # The fully connected attributes are checked last.
    expected_layers.extend(
        [
            ("fc1", nn.Linear),
            ("fc1_bn", nn.BatchNorm1d),
            ("fc2", nn.Linear),
        ]
    )

    net = ResNet(torch.device("cpu"))

    for attr_name, layer_type in expected_layers:
        assert isinstance(getattr(net, attr_name), layer_type)