本文目录导读:
我来详细讲解 Python 中 unittest 框架的写法,从基础到进阶。
基本结构
import unittest
# 要测试的函数或类
def add(x, y):
return x + y
class TestAddFunction(unittest.TestCase): # 继承 TestCase
def test_add_positive(self): # 测试方法必须以 test_ 开头
"""测试正数相加"""
result = add(3, 5)
self.assertEqual(result, 8) # 断言:期望结果等于 8
def test_add_negative(self):
"""测试负数相加"""
result = add(-1, -2)
self.assertEqual(result, -3)
def test_add_zero(self):
"""测试加0"""
self.assertEqual(add(5, 0), 5)
if __name__ == '__main__':
unittest.main() # 运行所有测试
常用的断言方法
class TestAssertions(unittest.TestCase):
def test_assertions(self):
# 相等/不相等
self.assertEqual(2 + 2, 4)
self.assertNotEqual(2 + 2, 5)
# 布尔值
self.assertTrue(1 < 2)
self.assertFalse(1 > 2)
# 是否为 None
self.assertIsNone(None)
self.assertIsNotNone([])
# 是否包含
self.assertIn('a', 'abc')
self.assertNotIn('x', 'abc')
# 比较
self.assertGreater(5, 3)
self.assertLess(3, 5)
self.assertGreaterEqual(5, 5)
# 类型检查
self.assertIsInstance('hello', str)
self.assertNotIsInstance(123, str)
# 近似相等(浮点数)
self.assertAlmostEqual(0.1 + 0.2, 0.3, places=10)
# 异常测试
with self.assertRaises(ValueError):
int('abc')
测试类的 setUp 和 tearDown
class TestDatabase(unittest.TestCase):
def setUp(self):
"""每个测试方法执行前运行"""
print("\n设置测试环境...")
self.test_data = [1, 2, 3, 4, 5]
self.db = Database() # 假设有 Database 类
def tearDown(self):
"""每个测试方法执行后运行"""
print("清理测试环境...")
self.db.close()
@classmethod
def setUpClass(cls):
"""整个测试类开始前运行一次"""
print("设置测试类环境...")
cls.shared_resource = create_shared_resource()
@classmethod
def tearDownClass(cls):
"""整个测试类结束后运行一次"""
print("清理测试类环境...")
cls.shared_resource.cleanup()
def test_data_length(self):
self.assertEqual(len(self.test_data), 5)
def test_data_type(self):
self.assertIsInstance(self.test_data, list)
测试异常和错误
class Calculator:
def divide(self, a, b):
if b == 0:
raise ValueError("除数不能为零")
return a / b
def sqrt(self, n):
if n < 0:
raise ValueError("不能对负数开平方")
return n ** 0.5
class TestCalculator(unittest.TestCase):
def setUp(self):
self.calc = Calculator()
def test_divide_normal(self):
self.assertEqual(self.calc.divide(10, 2), 5)
def test_divide_by_zero(self):
# 方法1:使用上下文管理器
with self.assertRaises(ValueError) as context:
self.calc.divide(10, 0)
# 验证异常信息
self.assertEqual(str(context.exception), "除数不能为零")
def test_sqrt_negative(self):
# 方法2:使用 assertRaises 作为装饰器
with self.assertRaises(ValueError):
self.calc.sqrt(-1)
def test_multiple_exceptions(self):
# 测试多种异常情况
with self.assertRaises((ValueError, TypeError)):
self.calc.divide("10", 2)
Mock 和补丁 (Patches)
from unittest.mock import Mock, patch, MagicMock
class UserService:
def __init__(self):
self.db = Database()
def get_user(self, user_id):
user = self.db.query(f"SELECT * FROM users WHERE id={user_id}")
return user
class TestUserService(unittest.TestCase):
@patch('module.Database') # 替换 Database 类
def test_get_user(self, MockDatabase):
# 设置 mock
mock_db = MockDatabase.return_value
mock_db.query.return_value = {"id": 1, "name": "张三"}
service = UserService()
user = service.get_user(1)
self.assertEqual(user["name"], "张三")
mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id=1")
def test_with_context_manager(self):
"""使用上下文管理器进行 mock"""
with patch('module.Database') as MockDatabase:
mock_db = MockDatabase.return_value
mock_db.query.return_value = {"id": 2, "name": "李四"}
service = UserService()
result = service.get_user(2)
self.assertEqual(result["name"], "李四")
@patch('module.requests.post')
def test_api_call(self, mock_post):
"""模拟 API 调用"""
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"success": True}
mock_post.return_value = mock_response
# 这里调用你的函数,它会使用 mock 的 requests.post
result = some_api_function()
self.assertTrue(result["success"])
测试参数化
# 方法1:使用 subTest
class TestMathOperations(unittest.TestCase):
def test_multiple_cases(self):
test_cases = [
(2, 3, 6), # (a, b, expected)
(0, 5, 0),
(-2, 3, -6),
(4, -3, -12),
]
for a, b, expected in test_cases:
with self.subTest(a=a, b=b):
self.assertEqual(multiply(a, b), expected)
# 方法2:使用参数化库(需要安装)
# pip install parameterized
from parameterized import parameterized
class TestParameterized(unittest.TestCase):
@parameterized.expand([
("positive", 2, 3, 6),
("zero", 0, 5, 0),
("negative", -2, 3, -6),
])
def test_multiply(self, name, a, b, expected):
self.assertEqual(multiply(a, b), expected)
跳过测试
class TestSkipConditions(unittest.TestCase):
@unittest.skip("暂时跳过此测试")
def test_temporary_skip(self):
self.fail("不应该执行到这里")
@unittest.skipIf(sys.version_info < (3, 8), "需要 Python 3.8+")
def test_python_version(self):
# 只有在 Python 3.8+ 才执行的测试
pass
@unittest.skipUnless(sys.platform.startswith("win"), "仅 Windows 系统")
def test_windows_only(self):
# 仅 Windows 系统执行的测试
pass
@unittest.expectedFailure
def test_expected_failure(self):
# 预期会失败的测试
self.assertEqual(1 + 1, 3)
组织测试文件和运行
# 目录结构
# project/
# ├── src/
# │ └── calculator.py
# └── tests/
# ├── __init__.py
# ├── test_calculator.py
# └── test_utils.py
# 运行方式:
# 1. 运行单个测试文件
# python -m unittest tests/test_calculator.py
# 2. 发现并运行所有测试
# python -m unittest discover tests/
# 3. 运行特定测试类或方法
# python -m unittest tests.test_calculator.TestCalculator.test_add
# 4. 详细输出
# python -m unittest -v tests/test_calculator.py
# 5. 用测试套件组合多个测试
def suite():
suite = unittest.TestSuite()
suite.addTest(TestCalculator('test_add'))
suite.addTest(TestCalculator('test_divide'))
return suite
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())
完整示例:测试一个类
import unittest
from unittest.mock import patch
class BankAccount:
def __init__(self, owner, balance=0):
self.owner = owner
self.balance = balance
self.transactions = []
def deposit(self, amount):
if amount <= 0:
raise ValueError("存款金额必须大于0")
self.balance += amount
self.transactions.append(f"存款: {amount}")
return self.balance
def withdraw(self, amount):
if amount <= 0:
raise ValueError("取款金额必须大于0")
if amount > self.balance:
raise ValueError("余额不足")
self.balance -= amount
self.transactions.append(f"取款: {amount}")
return self.balance
def get_balance(self):
return self.balance
def get_transaction_count(self):
return len(self.transactions)
class TestBankAccount(unittest.TestCase):
def setUp(self):
"""每个测试前创建新账户"""
self.account = BankAccount("张三", 1000)
def test_initial_balance(self):
"""测试初始余额"""
self.assertEqual(self.account.get_balance(), 1000)
self.assertEqual(self.account.owner, "张三")
def test_deposit(self):
"""测试存款"""
new_balance = self.account.deposit(500)
self.assertEqual(new_balance, 1500)
self.assertEqual(self.account.get_balance(), 1500)
def test_deposit_negative(self):
"""测试存款负数"""
with self.assertRaises(ValueError) as context:
self.account.deposit(-100)
self.assertIn("大于0", str(context.exception))
def test_withdraw(self):
"""测试取款"""
new_balance = self.account.withdraw(300)
self.assertEqual(new_balance, 700)
self.assertEqual(self.account.get_balance(), 700)
def test_withdraw_insufficient(self):
"""测试余额不足"""
with self.assertRaises(ValueError) as context:
self.account.withdraw(2000)
self.assertIn("不足", str(context.exception))
def test_transaction_tracking(self):
"""测试交易记录"""
self.account.deposit(500)
self.account.withdraw(200)
self.account.deposit(100)
self.assertEqual(self.account.get_transaction_count(), 3)
@patch('builtins.print')
def test_deposit_logs(self, mock_print):
"""测试存款是否打印日志(假设有打印日志的功能)"""
# deposit 方法中有 print 语句,这里可以测试
if __name__ == '__main__':
unittest.main(verbosity=2) # verbosity=2 显示更多细节
- 命名规范:测试类以
Test开头,方法以test_开头 - 单一职责:每个测试方法只测试一个功能点
- 清晰的命名:测试方法名应该描述测试的内容和期望结果
- 使用 setUp/tearDown:避免重复的初始化代码
- 隔离测试:测试之间不应相互依赖
- 测试边界条件:测试正常情况、异常情况、边界值
- 使用 Mock:隔离外部依赖(数据库、API 等)
- 保持简单:测试代码也应易于阅读和维护
这样,你就可以写出结构清晰、可维护的单元测试了!