Coverage for tests/tests_advertrain/test_models.py: 100%
56 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
1import torch
2import torch.nn as nn
4from robustML.advertrain.models import ConvNet, ResNet
# Seed torch's global RNG once at import time so the random dummy inputs
# below (and, presumably, the models' weight initialisation) are reproducible.
_SEED = 0
torch.manual_seed(_SEED)
def test_ConvNet_initialization():
    """Verify that ConvNet exposes each expected layer attribute with the right type.

    Attributes are checked in a fixed order; a missing attribute raises
    AttributeError, and a wrongly-typed one fails the isinstance assertion.
    """
    model = ConvNet(torch.device("cpu"))

    # Ordered mapping of attribute name -> expected module class.
    expected_layer_types = {
        "conv1": nn.Conv2d,
        "conv2_1": nn.Conv2d,
        "conv3_1": nn.Conv2d,
        "conv4_1": nn.Conv2d,
        "pooling": nn.MaxPool2d,
        "activation": nn.ReLU,
        "linear1": nn.Linear,
        "linear2": nn.Linear,
        "linear3": nn.Linear,
    }
    for attr_name, layer_type in expected_layer_types.items():
        assert isinstance(getattr(model, attr_name), layer_type)
def test_ConvNet_forward_pass():
    """Run one random 3x64x128 input through ConvNet and check the output shape.

    The model is built before the input is drawn so that both consume the
    global RNG in the same order as before.
    """
    cpu = torch.device("cpu")
    net = ConvNet(cpu)

    # Batch of one image: (batch=1, channels=3, height=64, width=128).
    batch = torch.randn(1, 3, 64, 128, device=cpu)

    # torch.Size is a tuple subclass, so a plain tuple comparison suffices.
    assert net(batch).shape == (1, 2)
def test_ResNet_initialization():
    """Verify that ResNet exposes the expected conv, norm, dropout and linear layers.

    The model is expected to have twelve ``convN`` Conv2d layers, each with a
    matching ``convN_bn`` BatchNorm2d. Every third one (N = 3, 6, 9, 12) also
    carries a ``convN_drop`` Dropout2d. The head consists of ``fc1`` (Linear),
    ``fc1_bn`` (BatchNorm1d) and ``fc2`` (Linear).
    """
    model = ResNet(torch.device("cpu"))

    for idx in range(1, 13):
        assert isinstance(getattr(model, f"conv{idx}"), nn.Conv2d)
        # Dropout is checked between the conv and its batch-norm,
        # only on the last conv of each group of three.
        if idx % 3 == 0:
            assert isinstance(getattr(model, f"conv{idx}_drop"), nn.Dropout2d)
        assert isinstance(getattr(model, f"conv{idx}_bn"), nn.BatchNorm2d)

    # Fully-connected head.
    assert isinstance(model.fc1, nn.Linear)
    assert isinstance(model.fc1_bn, nn.BatchNorm1d)
    assert isinstance(model.fc2, nn.Linear)