单元测试unittest怎么写?

访客 python案例 2

本文目录导读:

  1. 基本结构
  2. 常用的断言方法
  3. 测试类的 setUp 和 tearDown
  4. 测试异常和错误
  5. Mock 和补丁 (Patches)
  6. 测试参数化
  7. 跳过测试
  8. 组织测试文件和运行
  9. 完整示例:测试一个类

我来详细讲解 Python 中 unittest 框架的写法,从基础到进阶。

基本结构

import unittest
# 要测试的函数或类
def add(x, y):
    return x + y
class TestAddFunction(unittest.TestCase):  # 继承 TestCase
    def test_add_positive(self):  # 测试方法必须以 test_ 开头
        """测试正数相加"""
        result = add(3, 5)
        self.assertEqual(result, 8)  # 断言:期望结果等于 8
    def test_add_negative(self):
        """测试负数相加"""
        result = add(-1, -2)
        self.assertEqual(result, -3)
    def test_add_zero(self):
        """测试加0"""
        self.assertEqual(add(5, 0), 5)
if __name__ == '__main__':
    unittest.main()  # 运行所有测试

常用的断言方法

class TestAssertions(unittest.TestCase):
    def test_assertions(self):
        # 相等/不相等
        self.assertEqual(2 + 2, 4)
        self.assertNotEqual(2 + 2, 5)
        # 布尔值
        self.assertTrue(1 < 2)
        self.assertFalse(1 > 2)
        # 是否为 None
        self.assertIsNone(None)
        self.assertIsNotNone([])
        # 是否包含
        self.assertIn('a', 'abc')
        self.assertNotIn('x', 'abc')
        # 比较
        self.assertGreater(5, 3)
        self.assertLess(3, 5)
        self.assertGreaterEqual(5, 5)
        # 类型检查
        self.assertIsInstance('hello', str)
        self.assertNotIsInstance(123, str)
        # 近似相等(浮点数)
        self.assertAlmostEqual(0.1 + 0.2, 0.3, places=10)
        # 异常测试
        with self.assertRaises(ValueError):
            int('abc')

测试类的 setUp 和 tearDown

class TestDatabase(unittest.TestCase):
    def setUp(self):
        """每个测试方法执行前运行"""
        print("\n设置测试环境...")
        self.test_data = [1, 2, 3, 4, 5]
        self.db = Database()  # 假设有 Database 类
    def tearDown(self):
        """每个测试方法执行后运行"""
        print("清理测试环境...")
        self.db.close()
    @classmethod
    def setUpClass(cls):
        """整个测试类开始前运行一次"""
        print("设置测试类环境...")
        cls.shared_resource = create_shared_resource()
    @classmethod
    def tearDownClass(cls):
        """整个测试类结束后运行一次"""
        print("清理测试类环境...")
        cls.shared_resource.cleanup()
    def test_data_length(self):
        self.assertEqual(len(self.test_data), 5)
    def test_data_type(self):
        self.assertIsInstance(self.test_data, list)

测试异常和错误

class Calculator:
    def divide(self, a, b):
        if b == 0:
            raise ValueError("除数不能为零")
        return a / b
    def sqrt(self, n):
        if n < 0:
            raise ValueError("不能对负数开平方")
        return n ** 0.5
class TestCalculator(unittest.TestCase):
    def setUp(self):
        self.calc = Calculator()
    def test_divide_normal(self):
        self.assertEqual(self.calc.divide(10, 2), 5)
    def test_divide_by_zero(self):
        # 方法1:使用上下文管理器
        with self.assertRaises(ValueError) as context:
            self.calc.divide(10, 0)
        # 验证异常信息
        self.assertEqual(str(context.exception), "除数不能为零")
    def test_sqrt_negative(self):
        # 方法2:使用 assertRaises 作为装饰器
        with self.assertRaises(ValueError):
            self.calc.sqrt(-1)
    def test_multiple_exceptions(self):
        # 测试多种异常情况
        with self.assertRaises((ValueError, TypeError)):
            self.calc.divide("10", 2)

Mock 和补丁 (Patches)

from unittest.mock import Mock, patch, MagicMock
class UserService:
    def __init__(self):
        self.db = Database()
    def get_user(self, user_id):
        user = self.db.query(f"SELECT * FROM users WHERE id={user_id}")
        return user
class TestUserService(unittest.TestCase):
    @patch('module.Database')  # 替换 Database 类
    def test_get_user(self, MockDatabase):
        # 设置 mock
        mock_db = MockDatabase.return_value
        mock_db.query.return_value = {"id": 1, "name": "张三"}
        service = UserService()
        user = service.get_user(1)
        self.assertEqual(user["name"], "张三")
        mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id=1")
    def test_with_context_manager(self):
        """使用上下文管理器进行 mock"""
        with patch('module.Database') as MockDatabase:
            mock_db = MockDatabase.return_value
            mock_db.query.return_value = {"id": 2, "name": "李四"}
            service = UserService()
            result = service.get_user(2)
            self.assertEqual(result["name"], "李四")
    @patch('module.requests.post')
    def test_api_call(self, mock_post):
        """模拟 API 调用"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"success": True}
        mock_post.return_value = mock_response
        # 这里调用你的函数,它会使用 mock 的 requests.post
        result = some_api_function()
        self.assertTrue(result["success"])

测试参数化

# 方法1:使用 subTest
class TestMathOperations(unittest.TestCase):
    def test_multiple_cases(self):
        test_cases = [
            (2, 3, 6),    # (a, b, expected)
            (0, 5, 0),
            (-2, 3, -6),
            (4, -3, -12),
        ]
        for a, b, expected in test_cases:
            with self.subTest(a=a, b=b):
                self.assertEqual(multiply(a, b), expected)
# 方法2:使用参数化库(需要安装)
# pip install parameterized
from parameterized import parameterized
class TestParameterized(unittest.TestCase):
    @parameterized.expand([
        ("positive", 2, 3, 6),
        ("zero", 0, 5, 0),
        ("negative", -2, 3, -6),
    ])
    def test_multiply(self, name, a, b, expected):
        self.assertEqual(multiply(a, b), expected)

跳过测试

class TestSkipConditions(unittest.TestCase):
    @unittest.skip("暂时跳过此测试")
    def test_temporary_skip(self):
        self.fail("不应该执行到这里")
    @unittest.skipIf(sys.version_info < (3, 8), "需要 Python 3.8+")
    def test_python_version(self):
        # 只有在 Python 3.8+ 才执行的测试
        pass
    @unittest.skipUnless(sys.platform.startswith("win"), "仅 Windows 系统")
    def test_windows_only(self):
        # 仅 Windows 系统执行的测试
        pass
    @unittest.expectedFailure
    def test_expected_failure(self):
        # 预期会失败的测试
        self.assertEqual(1 + 1, 3)

组织测试文件和运行

# 目录结构
# project/
#   ├── src/
#   │   └── calculator.py
#   └── tests/
#       ├── __init__.py
#       ├── test_calculator.py
#       └── test_utils.py
# 运行方式:
# 1. 运行单个测试文件
# python -m unittest tests/test_calculator.py
# 2. 发现并运行所有测试
# python -m unittest discover tests/
# 3. 运行特定测试类或方法
# python -m unittest tests.test_calculator.TestCalculator.test_add
# 4. 详细输出
# python -m unittest -v tests/test_calculator.py
# 5. 用测试套件组合多个测试
def suite():
    suite = unittest.TestSuite()
    suite.addTest(TestCalculator('test_add'))
    suite.addTest(TestCalculator('test_divide'))
    return suite
if __name__ == '__main__':
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())

完整示例:测试一个类

import unittest
from unittest.mock import patch
class BankAccount:
    def __init__(self, owner, balance=0):
        self.owner = owner
        self.balance = balance
        self.transactions = []
    def deposit(self, amount):
        if amount <= 0:
            raise ValueError("存款金额必须大于0")
        self.balance += amount
        self.transactions.append(f"存款: {amount}")
        return self.balance
    def withdraw(self, amount):
        if amount <= 0:
            raise ValueError("取款金额必须大于0")
        if amount > self.balance:
            raise ValueError("余额不足")
        self.balance -= amount
        self.transactions.append(f"取款: {amount}")
        return self.balance
    def get_balance(self):
        return self.balance
    def get_transaction_count(self):
        return len(self.transactions)
class TestBankAccount(unittest.TestCase):
    def setUp(self):
        """每个测试前创建新账户"""
        self.account = BankAccount("张三", 1000)
    def test_initial_balance(self):
        """测试初始余额"""
        self.assertEqual(self.account.get_balance(), 1000)
        self.assertEqual(self.account.owner, "张三")
    def test_deposit(self):
        """测试存款"""
        new_balance = self.account.deposit(500)
        self.assertEqual(new_balance, 1500)
        self.assertEqual(self.account.get_balance(), 1500)
    def test_deposit_negative(self):
        """测试存款负数"""
        with self.assertRaises(ValueError) as context:
            self.account.deposit(-100)
        self.assertIn("大于0", str(context.exception))
    def test_withdraw(self):
        """测试取款"""
        new_balance = self.account.withdraw(300)
        self.assertEqual(new_balance, 700)
        self.assertEqual(self.account.get_balance(), 700)
    def test_withdraw_insufficient(self):
        """测试余额不足"""
        with self.assertRaises(ValueError) as context:
            self.account.withdraw(2000)
        self.assertIn("不足", str(context.exception))
    def test_transaction_tracking(self):
        """测试交易记录"""
        self.account.deposit(500)
        self.account.withdraw(200)
        self.account.deposit(100)
        self.assertEqual(self.account.get_transaction_count(), 3)
    @patch('builtins.print')
    def test_deposit_logs(self, mock_print):
        """测试存款是否打印日志(假设有打印日志的功能)"""
        # deposit 方法中有 print 语句,这里可以测试
if __name__ == '__main__':
    unittest.main(verbosity=2)  # verbosity=2 显示更多细节
  1. 命名规范:测试类以 Test 开头,方法以 test_ 开头
  2. 单一职责:每个测试方法只测试一个功能点
  3. 清晰的命名:测试方法名应该描述测试的内容和期望结果
  4. 使用 setUp/tearDown:避免重复的初始化代码
  5. 隔离测试:测试之间不应相互依赖
  6. 测试边界条件:测试正常情况、异常情况、边界值
  7. 使用 Mock:隔离外部依赖(数据库、API 等)
  8. 保持简单:测试代码也应易于阅读和维护

这样,你就可以写出结构清晰、可维护的单元测试了!

标签: 单元测试 unittest

抱歉,评论功能暂时关闭!