Coverage for tests/tests_advertrain/test_training/test_autoattack_training.py: 97%
29 statements
Created by coverage.py v7.9.2 at 2025-09-10 08:11 +0000
1import pytest
2import torch
3from torch.nn import Linear, Module
4from torch.optim import SGD
6from robustML.advertrain.dependencies.autoattack import APGDAttack
7from robustML.advertrain.training.autoattack_training import AutoAttackTraining
class SimpleModel(Module):
    """Minimal one-layer network used as a stand-in model in these tests.

    Maps inputs whose last dimension is 10 to 2 output scores through a
    single ``Linear`` layer.
    """

    def __init__(self):
        super().__init__()
        self.linear = Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Return ``self.linear(x)``; ``x`` must have a last dimension of size 10."""
        scores = self.linear(x)
        return scores
@pytest.fixture
def simple_model():
    """Provide a freshly constructed ``SimpleModel`` for each test."""
    model = SimpleModel()
    return model
@pytest.fixture
def simple_optimizer(simple_model):
    """Provide plain SGD (learning rate 0.01) over ``simple_model``'s parameters."""
    params = simple_model.parameters()
    optimizer = SGD(params, lr=0.01)
    return optimizer
@pytest.fixture
def simple_loss_func():
    """Provide a ``torch.nn.CrossEntropyLoss`` instance with default settings."""
    criterion = torch.nn.CrossEntropyLoss()
    return criterion
@pytest.fixture
def simple_device():
    """Provide the CPU device, keeping these tests independent of GPU availability."""
    cpu = torch.device('cpu')
    return cpu
def test_auto_attack_training_initialization(simple_model, simple_optimizer, simple_loss_func, simple_device):
    """Constructing ``AutoAttackTraining`` stores its settings and builds an APGD attack."""
    epsilon = 0.1
    loss_name = 'ce'

    # Positional order mirrors the constructor call: model, optimizer,
    # loss function, device, APGD loss name, epsilon.
    trainer = AutoAttackTraining(
        simple_model,
        simple_optimizer,
        simple_loss_func,
        simple_device,
        loss_name,
        epsilon,
    )

    assert trainer.epsilon == epsilon
    assert trainer.apgd_loss == loss_name
    # The constructor is expected to create the APGD attacker object itself.
    assert isinstance(trainer.apgd, APGDAttack)