你是否在寻找用Cython将Python关键代码编译成C扩展的优化案例?

访客 性能优化 1

本文目录导读:

  1. 案例1:数值计算优化 - 矩阵乘法
  2. 案例2:字符串处理优化 - 文本处理
  3. 案例3:循环密集型优化 - 图像处理
  4. 案例4:数据结构优化 - 链表操作
  5. 编译和测试建议

我来分享几个典型的使用Cython优化Python代码的实战案例。

案例1:数值计算优化 - 矩阵乘法

原始Python代码(慢)

# matrix.py
def matrix_multiply(A, B):
    """Return the product A x B as a list of lists of floats.

    Pure-Python baseline: a textbook triple loop with no vectorisation.
    ``A`` and ``B`` may be nested lists or 2-D NumPy arrays -- anything that
    supports ``len()`` and ``x[i][j]`` indexing.

    The inner dimension is taken from ``len(B)``: a row of ``A`` shorter than
    that raises ``IndexError``, and extra trailing entries in a row of ``A``
    are ignored.
    """
    n_rows = len(A)
    n_cols = len(B[0])
    n_inner = len(B)
    product = []
    for row_idx in range(n_rows):
        a_row = A[row_idx]
        out_row = []
        for col_idx in range(n_cols):
            # Accumulate in the same order as a plain dot product so results
            # are bit-for-bit reproducible.
            acc = 0.0
            for k in range(n_inner):
                acc += a_row[k] * B[k][col_idx]
            out_row.append(acc)
        product.append(out_row)
    return product

Cython优化版本

# matrix_cy.pyx
import numpy as np

cimport cython


@cython.boundscheck(False)  # every index below is proven in range by the shape check
@cython.wraparound(False)   # no negative indices are used
def matrix_multiply_cy(const double[:, :] A not None, const double[:, :] B not None):
    """Return the matrix product ``A x B`` as a new float64 NumPy array.

    Parameters
    ----------
    A : 2-D float64 buffer of shape (n, p)
    B : 2-D float64 buffer of shape (p, m)
        ``const`` memoryviews also accept read-only arrays.

    Returns
    -------
    numpy.ndarray of shape (n, m), dtype float64.

    Raises
    ------
    ValueError
        If the inner dimensions of ``A`` and ``B`` differ.
    """
    cdef Py_ssize_t n = A.shape[0]
    cdef Py_ssize_t p = A.shape[1]
    cdef Py_ssize_t m = B.shape[1]
    cdef Py_ssize_t i, j, k
    cdef double total
    cdef double[:, :] result

    if B.shape[0] != p:
        # Bounds checking is disabled, so reject a mismatch up front instead
        # of silently reading past the end of B.
        raise ValueError(
            f"shape mismatch: A is ({n}, {p}) but B is ({B.shape[0]}, {m})"
        )

    out_arr = np.zeros((n, m), dtype=np.float64)
    result = out_arr
    for i in range(n):
        for j in range(m):
            total = 0.0
            for k in range(p):
                total += A[i, k] * B[k, j]
            result[i, j] = total
    return out_arr

setup.py配置

# setup.py -- build the Cython extension in place.
from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy

extensions = [
    Extension(
        "matrix_cy",
        ["matrix_cy.pyx"],
        # Attach NumPy's C header path to the extension itself (needed by any
        # module that does `cimport numpy`).
        include_dirs=[numpy.get_include()],
        # Opt out of NumPy's deprecated C API to silence the build warning.
        define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
    ),
]

setup(
    ext_modules=cythonize(
        extensions,
        # Compile with Python 3 semantics; Cython 0.29 otherwise falls back to
        # language_level=2 and emits a FutureWarning.
        compiler_directives={"language_level": 3},
    ),
)

性能对比

import time

import numpy as np

import matrix_cy
from matrix import matrix_multiply

# Test data. The pure-Python version is an O(n**3) loop and can take a long
# time at this size; lower n for a quick check.
n = 500
A = np.random.rand(n, n)
B = np.random.rand(n, n)

# Python version: give it plain lists, which is what pure-Python code normally
# operates on. Indexing NumPy arrays element by element from Python is
# typically much slower still and would inflate the apparent speed-up.
A_list = A.tolist()
B_list = B.tolist()
start = time.perf_counter()  # monotonic, high-resolution clock for durations
python_result = matrix_multiply(A_list, B_list)
print(f"Python: {time.perf_counter() - start:.3f}s")

# Cython version
start = time.perf_counter()
cython_result = matrix_cy.matrix_multiply_cy(A, B)
print(f"Cython: {time.perf_counter() - start:.3f}s")

# A timing comparison only means something if both versions agree.
if not np.allclose(python_result, cython_result):
    raise RuntimeError("Python and Cython results differ")
# Expected: Cython is one to two orders of magnitude faster (measure locally).

案例2:字符串处理优化 - 文本处理

原始Python代码

# text_process.py
def count_words(text_list):
    """Return a dict mapping each lower-cased word to its occurrence count.

    Each text is lower-cased and split on runs of whitespace with
    ``.lower().split()``. Keys appear in the order each word is first seen.
    """
    tally = {}
    for line in text_list:
        for token in line.lower().split():
            if token in tally:
                tally[token] += 1
            else:
                tally[token] = 1
    return tally

Cython优化版本

# text_process_cy.pyx
# distutils: language = c++
from cython.operator cimport dereference as deref, preincrement as inc
from libcpp.map cimport map
from libcpp.string cimport string


def count_words_cy(list text_list):
    """Count lower-cased, whitespace-separated words in ``text_list``.

    Items may be ``str`` (lower-cased with ``str.lower()``, then UTF-8
    encoded) or ``bytes`` (assumed UTF-8; ASCII letters are lower-cased).
    Words are split on ASCII whitespace (space, \\t, \\n, \\v, \\f, \\r), the
    same set ``bytes.split()`` uses. ``str.split()`` additionally splits on
    U+001C-U+001F and Unicode spaces such as U+00A0; those are not
    separators here.

    The input buffers are only read, never modified. The returned dict has
    ``str`` keys in byte-wise sorted order (the iteration order of
    ``std::map``).

    Raises
    ------
    TypeError
        If an item is neither ``str`` nor ``bytes``.
    UnicodeDecodeError
        If a ``bytes`` item contains a word that is not valid UTF-8.
    """
    cdef map[string, int] word_count
    cdef map[string, int].iterator it
    cdef string word
    cdef bytes data
    cdef const char* p
    cdef Py_ssize_t i, n
    cdef unsigned char c
    cdef dict result = {}

    for text in text_list:
        if isinstance(text, str):
            data = text.lower().encode('utf-8')
        elif isinstance(text, bytes):
            data = text
        else:
            raise TypeError(f"expected str or bytes, got {type(text).__name__}")

        # Read-only scan of the buffer (unlike strtok, nothing is written back).
        p = data
        n = len(data)
        # Position n acts as a virtual trailing space so the last word is flushed.
        for i in range(n + 1):
            c = 32 if i == n else <unsigned char>p[i]
            if c == 32 or 9 <= c <= 13:
                if not word.empty():
                    word_count[word] = word_count[word] + 1
                    word.clear()
            else:
                if 65 <= c <= 90:  # ASCII 'A'-'Z' -> 'a'-'z', locale independent
                    c += 32
                word.push_back(<char>c)

    # Convert the C++ map back to a Python dict.
    it = word_count.begin()
    while it != word_count.end():
        result[deref(it).first.decode('utf-8')] = deref(it).second
        inc(it)
    return result

案例3:循环密集型优化 - 图像处理

原始Python代码

# image_filter.py
import numpy as np


def apply_filter(image, kernel):
    """Correlate a 2-D image with a square kernel, clamping sums to [0, 255].

    Pure-Python baseline for ``apply_filter_cy``.

    Parameters
    ----------
    image : 2-D NumPy array
    kernel : square 2-D NumPy array of weights

    Returns
    -------
    Array with the same shape and dtype as ``image``. Only pixels whose full
    kernel window lies inside the image are computed. A border of width
    ``kernel.shape[0] // 2`` is left at 0. The clamped sum is stored with
    NumPy's usual float-to-dtype conversion (truncation for integer dtypes).

    Raises
    ------
    ValueError
        If ``kernel`` is not square.
    """
    height, width = image.shape
    kernel_size = kernel.shape[0]
    if kernel.shape[1] != kernel_size:
        raise ValueError(f"kernel must be square, got shape {kernel.shape}")
    padding = kernel_size // 2
    output = np.zeros_like(image)
    for i in range(padding, height - padding):
        for j in range(padding, width - padding):
            pixel_sum = 0.0
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    pixel_sum += (
                        image[i - padding + ki, j - padding + kj] * kernel[ki, kj]
                    )
            output[i, j] = min(255, max(0, pixel_sum))
    return output

Cython优化版本

# image_filter_cy.pyx
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport fmax, fmin

# Initialise NumPy's C API before any of it is used (required on Cython < 3,
# harmless on Cython 3).
np.import_array()


@cython.boundscheck(False)  # indices are in range once the kernel is known to be square
@cython.wraparound(False)   # no negative indices are used
def apply_filter_cy(const np.uint8_t[:, :] image not None,
                    const np.float64_t[:, :] kernel not None):
    """Correlate a uint8 image with a square float64 kernel, clamping to [0, 255].

    Same result as the pure-Python ``apply_filter`` for uint8 input: only
    pixels whose full kernel window lies inside the image are computed. The
    border of width ``kernel.shape[0] // 2`` stays 0, and the clamped sum is
    truncated to uint8.

    Returns
    -------
    numpy.ndarray of dtype uint8 with the image's shape.

    Raises
    ------
    ValueError
        If ``kernel`` is not square.
    """
    cdef Py_ssize_t height = image.shape[0]
    cdef Py_ssize_t width = image.shape[1]
    cdef Py_ssize_t kernel_size = kernel.shape[0]
    cdef Py_ssize_t padding = kernel_size // 2
    cdef Py_ssize_t i, j, ki, kj
    cdef double pixel_sum
    cdef np.uint8_t[:, :] out_view

    if kernel.shape[1] != kernel_size:
        # Required for memory safety: bounds checking is disabled above.
        raise ValueError(
            f"kernel must be square, got shape ({kernel.shape[0]}, {kernel.shape[1]})"
        )

    output = np.zeros((height, width), dtype=np.uint8)
    out_view = output
    for i in range(padding, height - padding):
        for j in range(padding, width - padding):
            pixel_sum = 0.0
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    pixel_sum += <double>image[i - padding + ki, j - padding + kj] * \
                                 kernel[ki, kj]
            out_view[i, j] = <np.uint8_t>fmin(255.0, fmax(0.0, pixel_sum))
    return output

案例4:数据结构优化 - 链表操作

Cython实现高效链表

# linked_list_cy.pyx
from libc.stdlib cimport malloc, free


cdef struct Node:
    int value
    Node* prev
    Node* next


cdef class LinkedList:
    """Doubly linked list of C ints with O(1) ``append`` and ``pop`` at the tail.

    Nodes are allocated with ``malloc``. They are freed by ``pop`` and, for any
    remaining nodes, when the object is deallocated.
    """
    cdef Node* head
    cdef Node* tail
    cdef int size

    def __cinit__(self):
        self.head = NULL
        self.tail = NULL
        self.size = 0

    def __dealloc__(self):
        # __dealloc__ (not __del__) is the hook Cython runs when a cdef class
        # instance is freed; __del__ is not supported there before Cython 3.
        cdef Node* current = self.head
        cdef Node* next_node
        while current != NULL:
            next_node = current.next
            free(current)
            current = next_node
        self.head = NULL
        self.tail = NULL

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    cpdef void append(self, int value) except *:
        """Append ``value`` at the tail. Raises MemoryError if allocation fails."""
        cdef Node* new_node = <Node*>malloc(sizeof(Node))
        if new_node == NULL:
            raise MemoryError()
        new_node.value = value
        new_node.prev = self.tail
        new_node.next = NULL
        if self.tail == NULL:
            self.head = new_node
        else:
            self.tail.next = new_node
        self.tail = new_node
        self.size += 1

    cpdef int pop(self) except? -1:
        """Remove and return the tail value. Raises IndexError if empty.

        ``except? -1`` lets the exception propagate to callers while -1 stays
        a valid stored value.
        """
        cdef Node* node = self.tail
        cdef int value
        if node == NULL:
            raise IndexError("Empty list")
        value = node.value
        self.tail = node.prev
        if self.tail == NULL:
            self.head = NULL
        else:
            self.tail.next = NULL
        free(node)
        self.size -= 1
        return value

    cpdef int get(self, int index) except? -1:
        """Return the value at ``index`` (0-based, from the head).

        Raises IndexError if ``index`` is negative or >= the list size.
        """
        cdef Node* current
        cdef int i
        if index < 0 or index >= self.size:
            raise IndexError("Index out of range")
        current = self.head
        for i in range(index):
            current = current.next
        return current.value
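
关键优化技巧总结:

  1. 类型声明:为变量、函数参数声明C类型
  2. 使用cdef:定义C级别的函数和变量
  3. 内存视图:使用类型化内存视图(如double[:, :])代替np.ndarray缓冲区语法来访问NumPy数组,获得C级别的元素访问速度
  4. 避免Python对象:在循环中尽量使用C类型
  5. 使用libc/libcpp:直接调用C标准库(libc)和C++标准库(libcpp)
  6. 编译优化:在setup.py中设置编译标志,并在确认索引不会越界后用@cython.boundscheck(False)、@cython.wraparound(False)关闭边界检查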

编译和测试建议

# Install build and test dependencies (a C/C++ compiler is also required)
pip install cython numpy pytest
# Build the extensions in place
python setup.py build_ext --inplace
# Run the tests (test_performance.py is not shown in this article;
# substitute your own test file)
python -m pytest test_performance.py -v

这些案例展示了Cython在不同场景下的优化效果:在循环密集的数值计算中通常可以获得10-100倍的性能提升;字符串处理的收益一般较小,因为Python内置的字符串方法本身已由C实现。实际效果请以自己的基准测试为准。

标签: Cython编译C扩展

抱歉,评论功能暂时关闭!