本文目录导读:
我来分享几个典型的使用Cython优化Python代码的实战案例。
案例1:数值计算优化 - 矩阵乘法
原始Python代码(慢)
# matrix.py
def matrix_multiply(A, B):
    """Multiply two matrices given as nested sequences, using plain Python loops.

    This is the deliberately naive baseline that the Cython version is
    compared against.

    Args:
        A: An ``n x p`` matrix (a sequence of ``n`` rows, each of length ``p``).
        B: A ``p x m`` matrix (a sequence of ``p`` rows, each of length ``m``).

    Returns:
        The ``n x m`` product as a list of lists of floats.

    Raises:
        ValueError: If a row of ``A`` does not have exactly ``len(B)`` entries.
        IndexError: If ``B`` is empty (its column count cannot be determined).
    """
    n = len(A)
    m = len(B[0])
    p = len(B)
    # Check the inner dimension up front: with a mismatch the loops below would
    # either raise IndexError or silently ignore trailing columns of A.
    for row in A:
        if len(row) != p:
            raise ValueError(
                f"inner dimensions differ: A has a row of length {len(row)}, "
                f"B has {p} rows"
            )
    result = [[0.0] * m for _ in range(n)]
    for i in range(n):
        row_a = A[i]  # hoisted: the same row is reused for every column j
        for j in range(m):
            total = 0.0
            for k in range(p):
                total += row_a[k] * B[k][j]
            result[i][j] = total
    return result
Cython优化版本
# matrix_cy.pyx
cimport cython
import numpy as np


@cython.boundscheck(False)  # safe: the shape check below keeps every index in range
@cython.wraparound(False)   # no negative indices are used
def matrix_multiply_cy(const double[:, :] A, const double[:, :] B):
    """Multiply two float64 matrices using typed memoryviews and C-level loops.

    Args:
        A: ``n x p`` float64 array (any object exposing a 2-D double buffer;
            read-only arrays are accepted because the views are ``const``).
        B: ``p x m`` float64 array.

    Returns:
        The ``n x m`` product as a NumPy ``float64`` array.

    Raises:
        ValueError: If ``A.shape[1] != B.shape[0]``.
    """
    cdef Py_ssize_t n = A.shape[0]
    cdef Py_ssize_t p = A.shape[1]
    cdef Py_ssize_t m = B.shape[1]
    cdef Py_ssize_t i, j, k
    cdef double total
    cdef double[:, :] result

    # Required before disabling bounds checks: with a mismatch the k loop would
    # read past the end of B (or silently skip rows of B).
    if B.shape[0] != p:
        raise ValueError(
            "inner dimensions differ: A.shape[1]=%d, B.shape[0]=%d" % (p, B.shape[0])
        )

    result = np.zeros((n, m), dtype=np.float64)
    for i in range(n):
        for j in range(m):
            total = 0.0
            for k in range(p):
                total += A[i, k] * B[k, j]
            result[i, j] = total
    return np.asarray(result)
setup.py配置
# setup.py -- build in place with:  python setup.py build_ext --inplace
from setuptools import Extension, setup
from Cython.Build import cythonize
import numpy

extensions = [
    Extension(
        "matrix_cy",
        ["matrix_cy.pyx"],
        # NumPy headers; needed by any module that cimports numpy, harmless otherwise.
        include_dirs=[numpy.get_include()],
    ),
]

setup(
    ext_modules=cythonize(
        extensions,
        # Pin Python 3 semantics; older Cython releases otherwise default to
        # language_level 2 and emit a FutureWarning.
        compiler_directives={"language_level": "3"},
    ),
)
性能对比
import time

import numpy as np

import matrix_cy
from matrix import matrix_multiply


def _timed(func, *args):
    """Run ``func(*args)`` once and return ``(result, elapsed_seconds)``."""
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start


def main(n=500):
    """Benchmark both matrix-multiply implementations on random ``n x n`` input."""
    A = np.random.rand(n, n)
    B = np.random.rand(n, n)

    # Feed the pure-Python version nested lists: indexing a NumPy array element
    # by element builds a NumPy scalar on every access, which would inflate the
    # baseline with overhead unrelated to the algorithm.
    python_result, python_time = _timed(matrix_multiply, A.tolist(), B.tolist())
    print(f"Python: {python_time:.3f}s")

    cython_result, cython_time = _timed(matrix_cy.matrix_multiply_cy, A, B)
    print(f"Cython: {cython_time:.3f}s")

    # A timing comparison is only meaningful if both versions agree.
    if not np.allclose(python_result, cython_result):
        raise RuntimeError("Python and Cython results differ")

    # Expected: Cython roughly 10-50x faster (varies with machine and compiler).
    if cython_time > 0:
        print(f"Speed-up: {python_time / cython_time:.1f}x")


if __name__ == "__main__":
    main()
案例2:字符串处理优化 - 文本处理
原始Python代码
# text_process.py
from collections import Counter


def count_words(text_list):
    """Count case-insensitive word occurrences across a list of texts.

    Each text is lower-cased and split on runs of whitespace (``split()`` with
    no arguments).

    Args:
        text_list: Iterable of strings (``bytes`` also work; the keys are then
            ``bytes``).

    Returns:
        A plain ``dict`` mapping each word to its count, in first-seen order.
    """
    word_count = Counter()
    for text in text_list:
        word_count.update(text.lower().split())
    # Return a plain dict so callers get the same type as before.
    return dict(word_count)
Cython优化版本
# text_process_cy.pyx
# distutils: language = c++
from cython.operator cimport dereference as deref, preincrement as inc
from libcpp.string cimport string
from libcpp.unordered_map cimport unordered_map


def count_words_cy(list text_list):
    """Count case-insensitive word occurrences, tallying in a C++ hash map.

    Tokenisation matches the pure-Python version: each text is lower-cased and
    split on runs of whitespace. ``bytes`` items are decoded as UTF-8 first, so
    both ``str`` and ``bytes`` inputs are counted and the keys are always ``str``.

    NOTE(review): ``str.split`` and dict hashing already run in C in the
    pure-Python version, and this one adds a UTF-8 round-trip per word, so
    measure before assuming a large speed-up here.

    Args:
        text_list: List of ``str`` or UTF-8 encoded ``bytes`` items.

    Returns:
        A ``dict`` mapping each word (``str``) to its count.

    Raises:
        UnicodeDecodeError: If a ``bytes`` item is not valid UTF-8.
    """
    cdef unordered_map[string, int] word_count
    cdef unordered_map[string, int].iterator it
    cdef string word
    cdef string key
    cdef str line
    cdef dict result = {}

    for text in text_list:
        # Never tokenise a bytes object's buffer in place (e.g. with strtok):
        # Python bytes are immutable and may be shared, so writing into them
        # corrupts other objects.
        if isinstance(text, bytes):
            line = (<bytes>text).decode("utf-8")
        else:
            line = text  # Cython type-checks this assignment
        for token in line.lower().split():
            word = (<str>token).encode("utf-8")
            # operator[] value-initialises a missing count to 0.
            word_count[word] = word_count[word] + 1

    it = word_count.begin()
    while it != word_count.end():
        key = deref(it).first
        result[key.decode("utf-8")] = deref(it).second
        inc(it)
    return result
案例3:循环密集型优化 - 图像处理
原始Python代码
# image_filter.py
import numpy as np


def apply_filter(image, kernel):
    """Correlate a 2-D image with a square kernel, clipping results to [0, 255].

    The kernel is applied as-is (correlation; it is not flipped). Pixels within
    ``kernel_size // 2`` of the border, where the window would leave the image,
    are left at 0.

    Args:
        image: 2-D array of shape ``(height, width)``.
        kernel: Square 2-D array of shape ``(k, k)``.

    Returns:
        An array with the same shape and dtype as ``image``. For integer dtypes,
        NumPy truncates the fractional part of each clipped sum on assignment.

    Raises:
        ValueError: If ``kernel`` is not square.
    """
    height, width = image.shape
    kernel_size = kernel.shape[0]
    # Both kernel loops run to kernel.shape[0]; a non-square kernel would
    # either raise IndexError or silently drop kernel columns.
    if kernel.shape[1] != kernel_size:
        raise ValueError(f"kernel must be square, got shape {kernel.shape}")
    padding = kernel_size // 2
    output = np.zeros_like(image)
    for i in range(padding, height - padding):
        for j in range(padding, width - padding):
            pixel_sum = 0.0
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    pixel_sum += image[i - padding + ki, j - padding + kj] * kernel[ki, kj]
            output[i, j] = min(255, max(0, pixel_sum))
    return output
Cython优化版本
# image_filter_cy.pyx
cimport cython
import numpy as np
cimport numpy as np
from libc.math cimport fmax, fmin

# Initialise the NumPy C API before it is used (harmless if the Cython version
# in use already does this automatically).
np.import_array()


@cython.boundscheck(False)  # safe: the loop limits keep every index inside the arrays
@cython.wraparound(False)   # no negative indices are used
def apply_filter_cy(const np.uint8_t[:, :] image,
                    const np.float64_t[:, :] kernel):
    """Correlate a uint8 image with a square float64 kernel, clipped to [0, 255].

    Same algorithm as the pure-Python ``apply_filter``. Border pixels whose
    window would leave the image stay 0, and each clipped sum is truncated to
    ``uint8``.

    Args:
        image: 2-D ``uint8`` array of shape ``(height, width)``. Read-only
            arrays are accepted because the view is ``const``.
        kernel: Square 2-D ``float64`` array of shape ``(k, k)``.

    Returns:
        A new ``uint8`` NumPy array of shape ``(height, width)``.

    Raises:
        ValueError: If ``kernel`` is not square.
    """
    cdef Py_ssize_t height = image.shape[0]
    cdef Py_ssize_t width = image.shape[1]
    cdef Py_ssize_t kernel_size = kernel.shape[0]
    cdef Py_ssize_t padding = kernel_size // 2
    cdef Py_ssize_t i, j, ki, kj
    cdef double pixel_sum
    cdef np.uint8_t[:, :] out_view

    # Bounds checking is off, so reject non-square kernels explicitly: both
    # inner loops index the kernel up to kernel_size.
    if kernel.shape[1] != kernel_size:
        raise ValueError("kernel must be square")

    output = np.zeros((height, width), dtype=np.uint8)
    out_view = output
    for i in range(padding, height - padding):
        for j in range(padding, width - padding):
            pixel_sum = 0.0
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    pixel_sum += <double>image[i - padding + ki, j - padding + kj] * kernel[ki, kj]
            out_view[i, j] = <np.uint8_t>fmin(255.0, fmax(0.0, pixel_sum))
    return output
案例4:数据结构优化 - 链表操作
Cython实现高效链表
# linked_list_cy.pyx
from libc.stdlib cimport free, malloc


cdef struct Node:
    int value
    Node* prev   # previous node; lets pop() unlink the tail in O(1)
    Node* next


cdef class LinkedList:
    """Doubly linked list of C ``int`` values held in malloc'd nodes.

    ``append`` and ``pop`` work at the tail in O(1); ``get`` walks from the
    head in O(index). All nodes are freed in ``__dealloc__``.
    """

    cdef Node* head
    cdef Node* tail
    cdef int size

    def __cinit__(self):
        self.head = NULL
        self.tail = NULL
        self.size = 0

    def __dealloc__(self):
        # __dealloc__ (not __del__) is the hook Cython guarantees to call when
        # an extension-type instance is freed, so the C nodes are never leaked.
        cdef Node* current = self.head
        cdef Node* next_node
        while current != NULL:
            next_node = current.next
            free(current)
            current = next_node
        self.head = NULL
        self.tail = NULL

    def __len__(self):
        """Return the number of stored values."""
        return self.size

    cpdef void append(self, int value) except *:
        """Append ``value`` at the tail.

        Raises:
            MemoryError: If the node cannot be allocated.
        """
        cdef Node* new_node = <Node*>malloc(sizeof(Node))
        if new_node == NULL:
            raise MemoryError()
        new_node.value = value
        new_node.prev = self.tail
        new_node.next = NULL
        if self.tail == NULL:
            self.head = new_node
        else:
            self.tail.next = new_node
        self.tail = new_node
        self.size += 1

    cpdef int pop(self) except? -1:
        """Remove and return the value at the tail.

        Raises:
            IndexError: If the list is empty.
        """
        cdef Node* node = self.tail
        cdef int value
        if node == NULL:
            raise IndexError("Empty list")
        value = node.value
        self.tail = node.prev
        if self.tail == NULL:
            self.head = NULL
        else:
            self.tail.next = NULL
        free(node)
        self.size -= 1
        return value

    cpdef int get(self, int index) except? -1:
        """Return the value at 0-based position ``index`` counted from the head.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        cdef Node* current
        cdef int i
        if index < 0 or index >= self.size:
            raise IndexError("Index out of range")
        current = self.head
        for i in range(index):
            current = current.next
        return current.value
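关键优化技巧总结:
- 类型声明:为变量、函数参数声明C类型
- 使用cdef:定义C级别的函数和变量
- 内存视图:使用 `double[:, :]` 等类型化内存视图(而不是 `np.ndarray` 对象)访问数组数据
- 避免Python对象:在循环中尽量使用C类型
- 使用libc/libcpp:调用C标准库(libc)和C++标准库(libcpp)中的函数与容器
- 编译优化:在setup.py中设置编译标志(如 `extra_compile_args`)和Cython编译器指令(如 `compiler_directives`)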
编译和测试建议
# 安装 Cython 和 NumPy
pip install cython numpy
# 编译扩展(生成的模块放在当前目录)
python setup.py build_ext --inplace
# 运行测试
python -m pytest test_performance.py -v
这些案例展示了Cython在不同场景下的优化思路。对于循环密集型的数值计算,通常可以获得10-100倍的性能提升;而字符串处理等主要依赖内置C实现(如 `str.split`、`dict`)的代码,提升往往有限。实际效果请以实测为准。