Coverage for tests/tests_advertrain/test_training/test_adversarial_training.py: 97%
33 statements
coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
import pytest
import torch
from torch.nn import Linear, Module
from torch.optim import SGD

from robustML.advertrain.training.adversarial_training import \
    AdversarialTraining


# Define a simple model with trainable parameters
class SimpleModel(Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(in_features=10, out_features=2)

    def forward(self, x):
        return self.linear(x)


# Fixtures providing the pieces needed to construct AdversarialTraining
@pytest.fixture
def mock_model():
    return SimpleModel()


@pytest.fixture
def mock_optimizer(mock_model):
    return SGD(mock_model.parameters(), lr=0.001)


@pytest.fixture
def mock_loss_func():
    return torch.nn.CrossEntropyLoss()


@pytest.fixture
def mock_device():
    return torch.device("cpu")


@pytest.fixture
def mock_epsilon():
    return 0.1


def test_initialization(mock_model, mock_optimizer, mock_loss_func, mock_device, mock_epsilon):
    """Check that AdversarialTraining stores its constructor arguments unchanged."""
    training = AdversarialTraining(mock_model, mock_optimizer, mock_loss_func, mock_device, mock_epsilon)
    assert training.model is mock_model
    assert training.optimizer is mock_optimizer
    assert training.loss_func is mock_loss_func
    assert training.device == mock_device
    assert training.epsilon == mock_epsilon
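# Note: a minimal FGSM-style sketch, written in plain PyTorch and independent of the
# AdversarialTraining API (only its constructor is exercised above), illustrating the
# role the epsilon fixture is assumed to play as a perturbation budget. The function
# name and its use here are illustrative assumptions, not part of the tested library.
def fgsm_example(model, loss_func, x, y, epsilon):
    """Return an epsilon-bounded adversarial example for input x (illustrative sketch)."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = loss_func(model(x_adv), y)
    loss.backward()
    # Step in the direction of the sign of the input gradient, scaled by epsilon
    return (x_adv + epsilon * x_adv.grad.sign()).detach()

# Example usage with the simple model defined above:
# fgsm_example(SimpleModel(), torch.nn.CrossEntropyLoss(),
#              torch.randn(4, 10), torch.randint(0, 2, (4,)), 0.1)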