构建基于 Python 的单元测试

这篇文章记录了如何基于 Python 构建一个 Web 系统的单元测试,涉及一些基本和高级用法。

测试分类

  • 单元测试:单个模块的测试
  • 集成测试:多个模块的测试
  • 功能测试:项目的功能测试

其实就是范围不同,单元测试仅是系统特定一部分的测试,功能测试是将系统作为整体进行测试,集成测试介于两者之间。

单元测试库

最常用的是 unittest 和 pytest

  • 继承 unittest 的 TestCase 类来组织单元测试
  • assert 语句用来检测是否符合预期,而 pytest 提供了一些更强大的 assert 方法
  • pytest 用来运行测试,它可以使用加强版的 assert,并且它完全支持 unittest

一个简单的单元测试

1
2
3
4
5
6
7
8
9
import unittest
from fizzbuzz import fizzbuzz


class TestFizzBuzz(unittest.TestCase):
def test_fizz(self):
for i in [3, 6, 9, 18]:
print('testing', i)
assert fizzbuzz(i) == 'Fizz'

运行:

1
2
3
4
5
6
7
8
9
(venv) $ pytest
========================== test session starts ===========================
platform darwin -- Python 3.8.6, pytest-6.1.2, py-1.9.0, pluggy-0.13.1
rootdir: /Users/miguel/testing
collected 1 items

test_fizzbuzz.py . [100%]

=========================== 1 passed in 0.03s ============================

pytest命令比较智能,它会自动识别单元测试,它假定以这样的名字:test_[something].py 或者 [something]_test.py 命名的模块都包含单元测试。同时它也会搜索子目录。

一般来说,单元测试统一放到 tests 目录下,和应用目录隔离开。

测试覆盖率

安装:pip install pytest-cov

运行 pytest --cov=fizzbuzz,可以针对 fizzbuzz 模块运行单元测试以及覆盖率

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
(venv) $ pytest --cov=fizzbuzz
========================== test session starts ===========================
platform darwin -- Python 3.8.6, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
rootdir: /Users/miguel/testing
plugins: cov-2.11.1
collected 3 items

test_fizzbuzz.py ... [100%]

---------- coverage: platform darwin, python 3.8.6-final-0 -----------
Name Stmts Miss Cover
---------------------------------
fizzbuzz.py 13 4 69%
---------------------------------
TOTAL 13 4 69%


=========================== 3 passed in 0.07s ============================

还有以下参数:

  • --cov-branch 针对分支处理,有多少个分支就统计多少次
  • --cov-report=term-missing 表示以何种方式展示报告,term-missing表示在terminal上展示,并且会额外加上缺少测试覆盖的代码行数,另外一个常用选项是html 在html上展示报告,很清晰,常用。

可以添加注释 pragma: no cover 来跳过该块代码的覆盖率检测

测试参数化

使用库 parameterized: pip install parameterized

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from parameterized import parameterized

# ...

class TestLife(unittest.TestCase):
# ...

@parameterized.expand([('pattern1.txt',), ('pattern2.txt',)])
def test_load(self, pattern):
life = Life()
life.load(pattern)
assert life.survival == [2, 3]
assert life.birth == [3]
assert set(life.living_cells()) == {
(10, 10), (11, 11), (15, 10), (17, 10)}
assert life.bounding_box() == (10, 10, 17, 11)

也可以使用列表推导式:

1
2
3
4
5
class TestLife(unittest.TestCase):
# ...

@parameterized.expand([(n,) for n in range(9)])
def test_advance_cell(self, num_neighbors):

支持多参数:

1
2
3
4
5
6
7
import itertools

class TestLife(unittest.TestCase):
# ...

@parameterized.expand(itertools.product([True, False], range(9)))
def test_advance_cell(self, alive, num_neighbors):

测试异常

1
2
3
4
5
6
7
8
9
10
11
import pytest

# ...

class TestLife(unittest.TestCase):
# ...

def test_load_invalid(self):
life = Life()
with pytest.raises(RuntimeError):
life.load('pattern4.txt')

Mocking

mocking 就是劫持函数或者功能,可以控制返回值或者其他东西的一种功能。在测试中如果对某个函数已经有了详尽的测试,那么在这个函数被调用的地方,就可以用mocking功能,节约资源。

unittest 里的 mock 模块,可以使用 mock.patch_object() 来替换函数或者方法

1
2
3
4
5
6
7
8
9
from unittest import mock

class TestLife(unittest.TestCase):
# ...

@mock.patch.object(Life, '_advance_cell')
def test_advance_false(self, mock_advance_cell):
mock_advance_cell.return_value = False
# ...

测试 Web 应用

最好将测试归集到一个继承 unittest.TestCase 的类里,这样可以公用 setUp 和 tearDown 方法,会有更好的性能,以及更方便。

WSGI 和 ASGI 都有特定的规则用于服务器如何传递到应用的请求。所以我们可以注入假的请求到应用上来模拟,而不用启动真正的服务器。这些 Web 框架都有所谓的测试客户端(test clients)来帮助实现单元测试,不需要任何网络,会向应用传递假的请求。如果 Web 框架没有提供的话,WSGI 应用可以使用 Werkzeug 库,ASGI 应用可以使用 async-asgi-testclient

比如,Flask 框架可以直接使用自带的 test client:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class TestWebApp(unittest.TestCase):
def setUp(self):
self.app = create_app()
self.appctx = self.app.app_context()
self.appctx.push()
db.create_all()
self.client = self.app.test_client()

def tearDown(self):
db.drop_all()
self.appctx.pop()
self.app = None
self.appctx = None
self.client = None

Tornado 框架可以继承 HTTPTestCase or AsyncHTTPTestCase 类来实现,其中它自带了 HTTPClient 和 AsyncHTTPClient,可以直接使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

class BaseTestCase(AsyncHTTPTestCase):
def setUp(self):
super(BaseTestCase, self).setUp()
self.db_session = test_session
self.db_session.commit()
self.cookie = SimpleCookie()

def get_app(self):
test_app = Application()
return test_app

def get_new_ioloop(self):
return IOLoop.current()

def get_url(self, path):
full_path = super(BaseTestCase, self).get_url('/api/v1{}'.format(path))
return full_path

def _update_cookies(self, headers):
try:
cookies = escape.native_str(headers['Set-Cookie'])
self.cookie.update(SimpleCookie(cookies))
except KeyError:
return

def make_response(self, req, resp):
response = Response()
response.status_code = getattr(resp, 'code', None)
response.headers = {k: v for k, v in list(resp.headers.items())}
response.encoding = get_encoding_from_headers(response.headers)
response.raw = resp
response.reason = response.raw.reason
response._content = resp.body

if isinstance(req.url, bytes):
response.url = req.url.decode('utf-8')
else:
response.url = req.url
extract_cookies_to_jar(response.cookies, req, resp)
response.request = req
return response

def send(self, url, method='GET', data=None, json_data=None, files=None, headers=None, **kwargs):
if 'follow_redirects' not in kwargs:
kwargs['follow_redirects'] = False
request = Request(url=self.get_url(url), files=files, data=data, json=json_data)
request_data = request.prepare()
if headers is None:
headers = {}
headers.update(request_data.headers)
cookie_sting = '; '.join([f'{key}={morsel.value}' for key, morsel in self.cookie.items()])
if cookie_sting != '':
headers.update({'Cookie': cookie_sting})
resp = self.fetch(url, method=method, headers=headers, body=request_data.body, allow_nonstandard_methods=True, **kwargs)
self._update_cookies(resp.headers)
response = self.make_response(request, resp)
self.db_session.rollback()
return response

def get(self, url, **kwargs):
response = self.send(url, method='GET', **kwargs)
return response

def patch(self, url, files=None, data=None, json_data=None):
response = self.send(url, method='PATCH', files=files, data=data, json_data=json_data)
return response

def post(self, url, files=None, data=None, json_data=None, **kwargs):
response = self.send(url, method='POST', files=files, data=data, json_data=json_data, **kwargs)
return response

def put(self, url, files=None, data=None, json_data=None):
response = self.send(url, method='PUT', files=files, data=data, json_data=json_data)
return response

测试 html 内容

没必要全部 match 去做测试,而是可以检查一部分内容是否存在,比如提交按钮是否存在于 html 中,而忽略其顺序等无关信息。

1
2
3
4
5
6
7
8
9
10
11
def test_registration_form(self):
response = self.client.get('/auth/register')
assert response.status_code == 200
html = response.get_data(as_text=True)

# make sure all the fields are included
assert 'name="username"' in html
assert 'name="email"' in html
assert 'name="password"' in html
assert 'name="password2"' in html
assert 'name="submit"' in html

这样的方式也适合于其他数据量比较大的测试,只需要测试关键部分即可。

提交表单

主要问题在于 CSRF token 怎么处理,可以先发一个 GET 请求,然后拿到 token,再去提交表单,这是一种方法。另一种方法就是在测试中禁掉 CSRF 的保护。

1
2
3
4
5
6
7
def setUp(self):
self.app = create_app()
self.app.config['WTF_CSRF_ENABLED'] = False # no CSRF during tests
self.appctx = self.app.app_context()
self.appctx.push()
db.create_all()
self.client = self.app.test_client()

测试表单验证

根据表单验证失败返回的语句进行判断

1
2
3
4
5
6
7
8
9
10
def test_register_user_mismatched_passwords(self):
response = self.client.post('/auth/register', data={
'username': 'alice',
'email': 'alice@example.com',
'password': 'foo',
'password2': 'bar',
})
assert response.status_code == 200
html = response.get_data(as_text=True)
assert 'Field must be equal to password.' in html

测试需要登陆验证的页面

有以下几点:

  1. setUp 方法初始化用户
  2. login 方法
  3. 完成对应测试

Example:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# ...
import re

class TestWebApp(unittest.TestCase):
# ...

def setUp(self):
self.app = create_app()
self.app.config['WTF_CSRF_ENABLED'] = False # no CSRF during tests
self.appctx = self.app.app_context()
self.appctx.push()
db.create_all()
self.populate_db()
self.client = self.app.test_client()

def populate_db(self):
user = User(username='susan', email='susan@example.com')
user.set_password('foo')
db.session.add(user)
db.session.commit()

def login(self):
self.client.post('/auth/login', data={
'username': 'susan',
'password': 'foo',
})

def test_write_post(self):
self.login()
response = self.client.post('/', data={'post': 'Hello, world!'},
follow_redirects=True)
assert response.status_code == 200
html = response.get_data(as_text=True)
assert 'Your post is now live!' in html
assert 'Hello, world!' in html
assert re.search(r'<span class="user_popup">\s*'
r'<a href="/user/susan">\s*'
r'susan\s*</a>\s*</span>\s*said', html) is not None

测试 API 服务器

比较简单,因为 API 接口第一涉及范围小,第二返回基本上都是 JSON,容易解析。

1
2
3
4
5
6
7
8
9
10
11
12
def test_api_register_user(self):
response = self.client.post('/api/users', json={
'username': 'bob',
'email': 'bob@example.com',
'password': 'bar'
})
assert response.status_code == 201

# make sure the user is in the database
user = User.query.filter_by(username='bob').first()
assert user is not None
assert user.email == 'bob@example.com'

参考