Coverage for tests/tests_advertrain/test_training/test_autoattack_training.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import pytest 

2import torch 

3from torch.nn import Linear, Module 

4from torch.optim import SGD 

5 

6from robustML.advertrain.dependencies.autoattack import APGDAttack 

7from robustML.advertrain.training.autoattack_training import AutoAttackTraining 

8 

9 

class SimpleModel(Module):
    """Minimal single-linear-layer network used as a stand-in model in tests.

    Args:
        in_features: Size of the last input dimension (default 10, matching
            the original hard-coded layer).
        out_features: Number of outputs / logits (default 2).
    """

    def __init__(self, in_features=10, out_features=2):
        # Zero-argument super() is the idiomatic Python 3 form.
        super().__init__()
        self.linear = Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear layer to ``x``.

        The last dimension of ``x`` must equal ``in_features``; the result has
        the same leading dimensions with a last dimension of ``out_features``.
        """
        return self.linear(x)

17 

18 

@pytest.fixture
def simple_model():
    """Provide a freshly constructed SimpleModel for each test."""
    model = SimpleModel()
    return model

22 

23 

@pytest.fixture
def simple_optimizer(simple_model):
    """Provide a plain SGD optimizer (lr=0.01) over ``simple_model``'s parameters."""
    learning_rate = 0.01
    params = simple_model.parameters()
    return SGD(params, lr=learning_rate)

27 

28 

@pytest.fixture
def simple_loss_func():
    """Provide a default-configured cross-entropy loss module."""
    criterion = torch.nn.CrossEntropyLoss()
    return criterion

32 

33 

@pytest.fixture
def simple_device():
    """Provide the CPU device so the tests run without a GPU."""
    cpu = torch.device('cpu')
    return cpu

37 

38 

def test_auto_attack_training_initialization(simple_model, simple_optimizer, simple_loss_func, simple_device):
    """AutoAttackTraining stores its APGD settings and builds an APGDAttack."""
    # Keep the constructor arguments and the expected values in one place.
    apgd_loss = 'ce'
    epsilon = 0.1

    training = AutoAttackTraining(
        simple_model, simple_optimizer, simple_loss_func, simple_device, apgd_loss, epsilon
    )

    assert training.epsilon == epsilon
    assert training.apgd_loss == apgd_loss
    assert isinstance(training.apgd, APGDAttack)