Testing Machine Learning Projects in TensorFlow 2.0+

When building and developing machine learning models, one of the most commonly asked questions is: how do I test or verify that the model's behaviour matches its specification?

I have seen examples of tests in various articles and open-source projects. However, none of them address the actual question of how to test a model's behaviour.

It was not until I came across Jeremy Jordan's Testing ML article and Eugene Yan's Testing ML implementation example that I grasped how the process works.

In this article, I will explain how I test my ML models within the TensorFlow framework.

Types of tests

From the articles above, ML tests can be broadly categorized into pre-train and post-train tests.

Pre-train tests are run before the actual training process starts. Such tests include, but are not limited to:

  • Checking that the model’s output shape matches the output shape of the dataset

  • Checking that a single training loop results in a decrease in loss

  • Checking that the model's output predictions fall within an expected range. For example, if the loss to minimise is categorical_crossentropy, we expect every predicted probability to lie between 0 and 1 and each prediction to sum to 1.0

  • Checking that the model can actually learn by overfitting it on the training set

  • Checking for data leakage between the train and test sets
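Several of the checks above can be sketched as plain assertion helpers. This is a minimal sketch, assuming a tf.keras classifier with a softmax output and one-hot labels; the helper names are hypothetical and not part of TensorFlow or of the articles mentioned earlier.

```python
import numpy as np
import tensorflow as tf

# Hypothetical helpers (not a TensorFlow API).
# Assumes a softmax classifier trained on one-hot labels.


def check_output_shape(model, x, y):
    """The model's per-sample output shape should match the label shape."""
    preds = model(x[:4], training=False)
    assert tuple(preds.shape[1:]) == tuple(y.shape[1:]), f"{preds.shape} vs {y.shape}"


def check_probabilities(model, x, atol=1e-5):
    """With a softmax output, every probability is in [0, 1] and each row sums to 1."""
    preds = np.asarray(model(x, training=False))
    assert np.all((preds >= 0.0) & (preds <= 1.0))
    np.testing.assert_allclose(preds.sum(axis=-1), 1.0, atol=atol)


def check_single_step_reduces_loss(model, x, y, lr=0.01):
    """One gradient step on a batch should lower the loss on that same batch.
    Runs on a copy so the real model's weights are left untouched."""
    clone = tf.keras.models.clone_model(model)
    clone.set_weights(model.get_weights())
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr)

    before = float(loss_fn(y, clone(x, training=False)))
    with tf.GradientTape() as tape:
        loss = loss_fn(y, clone(x, training=True))
    grads = tape.gradient(loss, clone.trainable_variables)
    optimizer.apply_gradients(zip(grads, clone.trainable_variables))
    after = float(loss_fn(y, clone(x, training=False)))
    assert after < before, f"loss did not decrease: {before:.4f} -> {after:.4f}"


def check_no_leakage(x_train, x_test):
    """No test sample should appear verbatim in the training set."""
    seen = {row.tobytes() for row in np.ascontiguousarray(x_train)}
    leaked = sum(row.tobytes() in seen for row in np.ascontiguousarray(x_test))
    assert leaked == 0, f"{leaked} test samples also appear in the training set"
```

The leakage check only catches exact duplicates stored with the same dtype. Near-duplicates, such as resized or re-encoded images, need a fuzzier comparison.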

If any of the above pre-train tests fail, the training process should halt immediately. This prevents wasting valuable resources, such as cloud GPU time, on a model that is already known to be broken.

Post-train tests are run after the model has been trained. Such tests are usually more specific to the problem, but they can be broadly categorized as follows:

  • Invariance tests: Testing that perturbing an input still yields the same output.

  • Directional expectation tests: Testing that perturbing an input changes the output in the expected direction.

  • Unit tests: Testing the model on a specific slice of the dataset.
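The invariance and directional checks above can be expressed as generic assertions around a prediction function. A minimal sketch, assuming predict_fn takes a batch and returns class probabilities of shape (batch, classes); assert_invariant and assert_directional are hypothetical helper names.

```python
import numpy as np

# Hypothetical helpers; predict_fn maps a batch to (batch, classes) probabilities.


def assert_invariant(predict_fn, x, perturb):
    """Invariance test: the predicted class must not change under `perturb`."""
    original = np.argmax(predict_fn(x), axis=-1)
    perturbed = np.argmax(predict_fn(perturb(x)), axis=-1)
    mismatches = np.flatnonzero(original != perturbed)
    assert mismatches.size == 0, f"prediction changed for samples {mismatches.tolist()}"


def assert_directional(predict_fn, x, perturb, class_index, increase=True):
    """Directional expectation test: `perturb` should move the score of
    `class_index` in the expected direction for every sample."""
    before = np.asarray(predict_fn(x))[:, class_index]
    after = np.asarray(predict_fn(perturb(x)))[:, class_index]
    ok = after > before if increase else after < before
    assert np.all(ok), f"score moved the wrong way for samples {np.flatnonzero(~ok).tolist()}"
```

In a post-train test class, predict_fn would typically be self.model.predict, and perturb a transformation such as a horizontal flip (invariance) or a brightness shift (directional).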

Test Structure

Given the above tests, how do we incorporate them into an existing project?

Most Python projects with tests use a test framework such as pytest, and we can leverage it for our model tests as well.

In my projects, I split the model tests into two folders inside the project's tests subdirectory: pretrain and posttrain.

Within each of these subdirectories, I further group the tests by the model behaviour they check:

    tests/
    |- pytest.ini
    |- pretrain/
      |- test_output_shape.py
      |- test_loss.py
      ...
    |- posttrain/
      |- test_rotation_invariance.py
      |- test_perspective_shift.py
      ...

Now, how do we fit these tests into the model training pipeline? This is where callbacks come into play.

Use of callbacks

One of the most useful features of TensorFlow is callbacks, which can be injected into a model's training or prediction loop when you call model.fit or model.predict.

We can define custom callbacks that run the pre-train tests before training starts and the post-train tests after it finishes.

These callbacks are passed to model.fit through its callbacks keyword argument, which takes a list of callback instances.

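The callback methods we want to override are:

  • on_train_begin: called once at the start of model.fit, before any training happens.

  • on_train_end: called once after all training stops.

Keras passes an optional logs argument to both methods, so overrides should accept it, e.g. def on_train_begin(self, logs=None).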
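To see when these hooks fire, and that an exception raised in on_train_begin stops model.fit before any epoch runs, here is a small sketch. HookRecorder, the toy model and the random data are hypothetical and only there for illustration.

```python
import numpy as np
import tensorflow as tf


class HookRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the order of Keras training hooks."""

    def __init__(self, fail_on_begin=False):
        super().__init__()
        self.events = []
        self.fail_on_begin = fail_on_begin

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")
        if self.fail_on_begin:
            # Simulates a failing pre-train test: raising here aborts model.fit.
            raise RuntimeError("pre-train tests failed")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_{epoch}_end")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 4)).astype("float32")
y = rng.normal(size=(16, 1)).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = HookRecorder()
model.fit(x, y, epochs=2, verbose=0, callbacks=[recorder])
print(recorder.events)  # ['train_begin', 'epoch_0_end', 'epoch_1_end', 'train_end']

failing = HookRecorder(fail_on_begin=True)
try:
    model.fit(x, y, epochs=2, verbose=0, callbacks=[failing])
except RuntimeError as err:
    print("training aborted:", err)
print(failing.events)
```

Because the exception propagates out of model.fit, no epoch runs and on_train_end is never reached, which is exactly the "fail fast" behaviour wanted for pre-train tests.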

An example of a custom pre-train test callback:

    from tensorflow.keras.callbacks import Callback

    # CustomTestRunner is defined later in this article.


    class PreTrainTest(Callback):
        def __init__(self, train_data, train_labels, test_data, test_labels):
            super().__init__()
            self.train_data = train_data
            self.train_labels = train_labels
            self.test_data = test_data
            self.test_labels = test_labels

        def on_train_begin(self, logs=None):
            # Keras calls this once before the first epoch; self.model is set
            # by Keras when the callback is passed to model.fit.
            # If any pre-train test fails, run() raises and model.fit aborts.
            CustomTestRunner(
                "tests",
                "pretrain",
                self.model,
                self.train_data,
                self.train_labels,
                self.test_data,
                self.test_labels,
            ).run()

Example of invoking the callback in model code:

    ...

    model.fit(
        Xtrain,
        ytrain,
        validation_data=(Xtest, ytest),
        epochs=30,
        batch_size=32,
        callbacks=[PreTrainTest(Xtrain, ytrain, Xtest, ytest)],
    )

Note the use of the CustomTestRunner class. It invokes pytest programmatically from within the callback, so the pre-train tests run before any training begins.
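CustomTestRunner relies on pytest.main, which takes the same arguments as the pytest command line and returns an exit code (a pytest.ExitCode member) rather than exiting the interpreter. A minimal sketch of that behaviour follows. The helper run_pytest_on and the throwaway test files are hypothetical and purely illustrative.

```python
import os
import tempfile
import textwrap

import pytest


def run_pytest_on(filename, source):
    """Hypothetical helper: write `source` to a throwaway test file and run pytest on it in-process."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, filename)
        with open(path, "w") as f:
            f.write(textwrap.dedent(source))
        # -q keeps the report short; "-p no:cacheprovider" skips writing .pytest_cache.
        return pytest.main(["-q", "-p", "no:cacheprovider", path])


# Distinct file names matter: modules imported by one pytest.main call stay
# cached in sys.modules for later calls in the same process.
passing = run_pytest_on("test_passing_demo.py", """
    def test_ok():
        assert 1 + 1 == 2
""")
failing = run_pytest_on("test_failing_demo.py", """
    def test_broken():
        assert 1 + 1 == 3
""")

print(passing == pytest.ExitCode.OK)            # True
print(failing == pytest.ExitCode.TESTS_FAILED)  # True
```

The module caching noted in the comment is also what lets a runner set attributes on an already-imported test module before handing it to pytest.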

An example of CustomTestRunner:

    import importlib
    import inspect
    import os
    import sys

    import pytest


    class CustomTestRunner:
        def __init__(self, directory, kind, model, train_data, train_labels, test_data, test_labels):
            self.directory = directory
            self.kind = kind
            self.model = model
            self.train_data = train_data
            self.train_labels = train_labels
            self.test_data = test_data
            self.test_labels = test_labels

        def run(self):
            testdir = os.path.join(self.directory, self.kind)
            sys.path.append(testdir)
            # Only pick up pytest-style test modules (skips __pycache__, helpers, etc.).
            testfiles = sorted(
                f for f in os.listdir(testdir) if f.startswith("test_") and f.endswith(".py")
            )

            for test in testfiles:
                # For each test file, import it with importlib so we can
                # set the model, train data etc. as class attributes.
                testpath = os.path.join(testdir, test)
                mod = importlib.import_module(os.path.splitext(test)[0])

                # Each test must be structured as a class for this to work.
                # Only touch classes defined in this module, not ones it imports.
                for _, found_class in inspect.getmembers(mod, inspect.isclass):
                    if found_class.__module__ != mod.__name__:
                        continue
                    setattr(found_class, "model", self.model)
                    setattr(found_class, "xtrain", self.train_data)
                    setattr(found_class, "ytrain", self.train_labels)
                    ...

                # pytest.main returns an exit code instead of exiting the interpreter;
                # anything other than OK (failures, errors, no tests collected) halts training.
                status = pytest.main([testpath])
                if status != pytest.ExitCode.OK:
                    raise RuntimeError(f"{self.kind} tests failed in {testpath}")

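The custom test runner takes the test directory path, the kind of tests to run (pretrain or posttrain), and the model and data that the tests need. It injects the model and data as attributes on each test class, runs pytest on the file, and raises an exception if any test fails.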
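The runner relies on two Python behaviours. First, importlib caches a module in sys.modules. In its default import mode, pytest imports test modules by name, so it should pick up the module the runner already imported. Second, attributes set on a class are visible through its instances. The second behaviour can be checked in isolation with a small sketch. The temporary module test_injection_demo and its TestShape class are hypothetical, illustrative names.

```python
import importlib
import inspect
import os
import sys
import tempfile
import textwrap

# Write a tiny, hypothetical test module to a temporary directory
# (left on disk for simplicity).
tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, "test_injection_demo.py"), "w") as f:
    f.write(textwrap.dedent("""
        class TestShape:
            def test_has_model(self):
                assert self.model == "injected model"
    """))

sys.path.append(tmpdir)
importlib.invalidate_caches()  # make sure the new file is visible to the import system
mod = importlib.import_module("test_injection_demo")

# Find the classes defined in the module itself (not ones it imports).
classes = [cls for _, cls in inspect.getmembers(mod, inspect.isclass)
           if cls.__module__ == mod.__name__]

# Inject a stand-in "model" as a class attribute, as CustomTestRunner does.
for cls in classes:
    setattr(cls, "model", "injected model")

# A fresh instance (like the one pytest creates per test) sees the attribute.
instance = classes[0]()
instance.test_has_model()  # passes: self.model resolves to the class attribute
print([cls.__name__ for cls in classes])  # ['TestShape']
```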

Each test file needs to define its test cases as methods on a class, as follows:

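    class TestMyPreTrain:
        # pytest only collects classes matching python_classes (default "Test*"),
        # so either use that prefix or configure python_classes in pytest.ini.
        def test_output_shape(self):
            """
            Here we have access to self.model and self.xtrain
            (injected by CustomTestRunner) to work on...
            """
            ...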

The same approach is adapted for the post-train tests in a separate callback, which overrides on_train_end instead.
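A post-train counterpart could mirror PreTrainTest but hook into on_train_end. In this sketch the runner class is passed in through a hypothetical runner_cls argument so the block runs on its own. In the setup above you would pass CustomTestRunner, whose constructor signature this sketch assumes.

```python
import tensorflow as tf


class PostTrainTest(tf.keras.callbacks.Callback):
    """Runs the tests in tests/posttrain once model.fit has finished."""

    def __init__(self, runner_cls, train_data, train_labels, test_data, test_labels):
        super().__init__()
        # Hypothetical parameter: a class with the same constructor as
        # CustomTestRunner and a run() method.
        self.runner_cls = runner_cls
        self.train_data = train_data
        self.train_labels = train_labels
        self.test_data = test_data
        self.test_labels = test_labels

    def on_train_end(self, logs=None):
        # Called once after the last epoch, with the trained weights in self.model.
        self.runner_cls(
            "tests",
            "posttrain",
            self.model,
            self.train_data,
            self.train_labels,
            self.test_data,
            self.test_labels,
        ).run()
```

Both callbacks can then be combined in a single call, e.g. callbacks=[PreTrainTest(Xtrain, ytrain, Xtest, ytest), PostTrainTest(CustomTestRunner, Xtrain, ytrain, Xtest, ytest)].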

This structure lets me test my models while working within the TensorFlow framework, with minimal disruption to the training workflow.

Happy hacking!!!