Simplifying python unittests using `inspect`
Published at Apr 14, 2024
For example, say you had a class like this:
class Testing:
    """Two interchangeable ways of fetching the first element of a sequence."""

    def give_first(self, nums):
        """Return the first element of *nums*; raises IndexError if it is empty."""
        return nums[0]

    def give_first_alt(self, nums):
        """Return the same value as give_first, going through a one-element slice.

        The slice is ``[:1]`` (the first element), not ``[:-1]`` (everything but
        the last element). With ``[:-1]`` a one-item list would slice to ``[]``
        and raise IndexError instead of agreeing with give_first.
        """
        return nums[:1][0]
Bad
For both of these methods, you expect the same output for a given input. The straightforward way of writing tests for this would be:
from src import main
import unittest
class TestSolution(unittest.TestCase):
    """Baseline: every (method, input) pair is checked by its own hand-written line."""

    def test_give_first(self):
        # Both implementations must agree on each input; each check builds a fresh instance.
        self.assertEqual(main.Testing().give_first(nums=[1, 2, 3]), 1)
        self.assertEqual(main.Testing().give_first_alt(nums=[1, 2, 3]), 1)
        self.assertEqual(main.Testing().give_first(nums=[4, 5, 6]), 4)
        self.assertEqual(main.Testing().give_first_alt(nums=[4, 5, 6]), 4)
        self.assertEqual(main.Testing().give_first(nums=[7, 8, 9]), 7)
        self.assertEqual(main.Testing().give_first_alt(nums=[7, 8, 9]), 7)


if __name__ == "__main__":
    # Lets the file be executed directly with the unittest runner.
    unittest.main()
You can run this with pytest via `pytest tests/tests.py`. But I felt this was quite annoying, for two main reasons:
- You are just plainly writing repeated code
- It’s annoying to maintain, especially as we add more functions that should behave the same way
Better
Remembering that you can dynamically access class methods in Python, I started with a simple approach using the `inspect.getmembers` function:
def test_give_first(self):
    """Run every plain function defined on main.Testing against the same cases."""
    # getmembers yields (name, value) pairs sorted by name; keep just the values.
    methods = [
        method
        for _, method in inspect.getmembers(main.Testing, predicate=inspect.isfunction)
    ]
    for method in methods:
        # subTest labels a failure with the offending method and lets the
        # remaining methods still be checked.
        with self.subTest(method=method.__name__):
            self.assertEqual(method(main.Testing(), [1, 2, 3]), 1)
            self.assertEqual(method(main.Testing(), [4, 5, 6]), 4)
            self.assertEqual(method(main.Testing(), [7, 8, 9]), 7)
This cleaned things up a bit, but then I realized that if you have multiple classes sharing this pattern, it makes sense to have a separate helper that gives you the list of methods, so you can easily reuse it for different test cases. So we can add the following to our `TestSolution` class:
@staticmethod
def method_gettr(problem) -> Generator[Callable[..., Any], None, None]:
    """Yield every plain Python function found on *problem*, sorted by name.

    Declared as a staticmethod because it takes no ``self``. It can therefore be
    called either as ``TestSolution.method_gettr(cls)`` or as
    ``self.method_gettr(cls)``.

    Args:
        problem: the class whose functions should be collected.

    Yields:
        Each member for which ``inspect.isfunction`` is true.
    """
    for _, method in inspect.getmembers(problem, predicate=inspect.isfunction):
        yield method
This simplifies our test method even further:
def test_give_first(self):
    """Check every method yielded by method_gettr against the same three cases."""
    for method in TestSolution.method_gettr(main.Testing):
        # Tag failures with the method that produced them and keep testing the rest.
        with self.subTest(method=method.__name__):
            self.assertEqual(method(main.Testing(), [1, 2, 3]), 1)
            self.assertEqual(method(main.Testing(), [4, 5, 6]), 4)
            self.assertEqual(method(main.Testing(), [7, 8, 9]), 7)
Even Better
For most cases, I think it’s reasonable to stop at this point, but curiosity tends to get the better of me, so I wanted to see if I could avoid writing all those pesky `self.assertEqual` calls repeatedly. One thing we can do is create a structure that holds the inputs and expected outputs, then loop through it. That gives us:
def test_give_first(self):
    """Check every method of main.Testing against a table of (input, expected) pairs."""
    inout = [
        ([1, 2, 3], 1),
        ([4, 5, 6], 4),
        ([7, 8, 9], 7),
    ]
    # The method list does not depend on the case, so collect it once up front.
    methods = list(TestSolution.method_gettr(main.Testing))
    for data, expected in inout:
        for method in methods:
            # Report which method/input pair failed and continue with the others.
            with self.subTest(method=method.__name__, data=data):
                self.assertEqual(method(main.Testing(), data), expected)
Fully generalized
Since we have essentially reduced the testing code to a function of some inputs and expected outputs, we can encapsulate it in another helper, like so:
import inspect
import unittest
from typing import Any, Callable, Generator

from src import main
class TestSolution(unittest.TestCase):
    """Table-driven tests that run every method of a class on shared cases."""

    @staticmethod
    def method_gettr(problem) -> Generator[Callable[..., Any], None, None]:
        """Yield every plain Python function on *problem*, sorted by name.

        This includes functions inherited from base classes. ``inspect.isfunction``
        filters out non-function attributes. This is a staticmethod because it
        takes no ``self``.
        """
        for _, method in inspect.getmembers(problem, predicate=inspect.isfunction):
            yield method

    def data_func(self, inout, problem):
        """Assert every method of *problem* maps each input to its expected value.

        Args:
            inout: iterable of ``(data, expected)`` pairs. Each method is called
                as ``method(problem(), data)``.
            problem: the class under test. A fresh instance is built per call.

        The name deliberately lacks a ``test_`` prefix so unittest does not try
        to run it on its own.
        """
        # Collect once: the method list does not depend on the case being checked.
        methods = list(TestSolution.method_gettr(problem))
        # Without this, a class with no functions would pass without checking anything.
        self.assertTrue(methods, f"no methods to test on {problem!r}")
        for data, expected in inout:
            for method in methods:
                with self.subTest(method=method.__name__, data=data):
                    self.assertEqual(method(problem(), data), expected)

    def test_give_first_general(self):
        # The helper is called data_func; there is no test_data_func attribute.
        self.data_func([
            ([1, 2, 3], 1),
            ([4, 5, 6], 4),
            ([7, 8, 9], 7),
        ], main.Testing)
Note
If you have some functions you want to run the same tests on, and others you want to run different tests on, all you have to do is modify the predicate used by `method_gettr`. `inspect.isfunction` matches every function defined in Python; to select only a subset of these, you can combine it with other predicates. We can pass our filter to `method_gettr()` like so:
@staticmethod
def method_gettr(f_filter, problem) -> Generator[Callable[..., Any], None, None]:
    """Yield the plain Python functions on *problem* that *f_filter* accepts.

    Args:
        f_filter: predicate called with each candidate function. Only functions
            for which it returns a truthy value are yielded. ``None`` keeps every
            function.
        problem: the class whose members are inspected.

    Yields:
        Matching functions, sorted by name (the order ``inspect.getmembers`` uses).
    """
    def wanted(func):
        # isfunction is checked first, so f_filter only ever sees real functions.
        return inspect.isfunction(func) and (f_filter is None or f_filter(func))

    for _, method in inspect.getmembers(problem, predicate=wanted):
        yield method
def data_func(self, inout, problem, f_filter=None):
    """Assert each method of *problem* accepted by *f_filter* maps inputs to expected values.

    Args:
        inout: iterable of ``(data, expected)`` pairs. Each method is called as
            ``method(problem(), data)``.
        problem: the class under test. A fresh instance is created for every call.
        f_filter: predicate selecting which functions to test. ``None`` (the
            default) tests all of them.
    """
    if f_filter is None:
        # Always hand method_gettr a callable filter.
        def f_filter(func):
            return True

    # method_gettr takes the filter first, then the class to inspect.
    methods = list(TestSolution.method_gettr(f_filter, problem))
    # A filter that matches nothing would otherwise let the test pass vacuously.
    self.assertTrue(methods, f"no methods of {problem!r} matched the filter")
    for data, expected in inout:
        for method in methods:
            with self.subTest(method=method.__name__, data=data):
                self.assertEqual(method(problem(), data), expected)
I added the following method to `main.Testing` just to demonstrate:
def give_last(self, nums: list[int], num: int) -> int:
    """Return the final element of *nums*.

    ``num`` is never read. The extra parameter only gives this method a
    different signature from give_first.
    """
    last_item = nums[-1]
    return last_item
Filter by name
# Keep only functions whose name begins with "give_first" (this also matches give_first_alt).
lambda func: func.__name__.startswith("give_first")
Filter by signature
As you may have already noticed, you can filter by practically any aspect of the signature.
def test_give_first_general(self):
    """Run only the methods whose signature is exactly ``(self, nums: list[int]) -> int``."""
    positional = inspect.Parameter.POSITIONAL_OR_KEYWORD
    # Reference signature that candidate methods must equal.
    t_signature = inspect.Signature(
        parameters=[
            inspect.Parameter(name="self", kind=positional),
            inspect.Parameter(name="nums", kind=positional, annotation=list[int]),
        ],
        return_annotation=int,
    )
    cases = [
        ([1, 2, 3], 1),
        ([4, 5, 6], 4),
        ([7, 8, 9], 7),
    ]

    def same_signature(func):
        # NOTE(review): Signature equality also compares annotations. Unannotated
        # methods, or string annotations under `from __future__ import annotations`,
        # will therefore not match.
        return inspect.signature(func) == t_signature

    self.data_func(cases, main.Testing, same_signature)
If you are curious, `t_signature` is an `inspect.Signature` object, and in this case it prints as follows:
(self, nums: list[int]) -> int
Full code
For the full code, check out the repo.