Python itertools模块教程

简介

很多Python开发者在处理序列拼接、排列组合、筛选分组这类需求时,第一反应是写嵌套循环、临时列表甚至复杂的条件判断——但往往写完发现既占内存,代码可读性也一般。

这时候就该请出Python标准库的「隐藏宝藏」itertools了!它提供了一系列内存友好的迭代器工具函数,专注于「处理可迭代对象的高效逻辑」,组合起来能实现很多复杂但优雅的功能,堪称「函数式工具链」的入门首选。


无限迭代器

itertools有3个专门用来生成「没有默认终止点」的迭代器,但实际使用时必须手动加终止条件,否则会陷入死循环!

1. count():等差数列无限生成器

count(start=0, step=1)会从start开始,按step(支持负数、浮点数)的步长无限输出数值,适合生成自然数、奇数/偶数序列这类规律数列。

import itertools

# 生成从3开始、步长为2的奇数序列(取前5个)
odd_with_start = itertools.count(3, 2)
for i, num in enumerate(odd_with_start):
    print(num, end=" ")
    if i == 4:
        break
# 输出: 3 5 7 9 11

2. cycle():序列无限循环器

cycle(iterable)会把传入的可迭代对象(字符串、列表、元组都可以)从头到尾无限重复,适合做循环标记、无限轮播数据源的场景。

import itertools

# 模拟红绿灯循环(红3秒绿2秒黄1秒,取两轮半后结束)
traffic_light = itertools.cycle([("红", 3), ("绿", 2), ("黄", 1)])
total_time = 0
max_time = 20

for light, sec in traffic_light:
    print(f"{light}灯亮{sec}秒")
    total_time += sec
    if total_time >= max_time:
        break

3. repeat():单元素重复器

repeat(object[, times])会重复输出指定的对象,默认无限次,指定times后会变成有限迭代器,适合初始化临时列表、测试批量处理函数。

import itertools

# 初始化一个长度为100的「空字典占位符」列表
dict_placeholders = list(itertools.repeat({}, 100))

# 测试批量打印(不指定times的话一定要加break)
test_printer = itertools.repeat("test itertools")
for i, content in enumerate(test_printer):
    print(content)
    if i == 2:
        break

有限迭代器

聊完三个「没个尽头」的生成器,接下来看看处理已有可迭代对象、快速「裁剪/拼接/筛选」的有限迭代器工具。

1. takewhile():条件前置筛选器

takewhile(predicate, iterable)会从迭代器的开头取出元素,直到谓词函数第一次返回False为止——注意和内置函数filter()的区别:filter()是筛选所有满足条件的元素,而takewhile()遇到不满足的就直接终止。

import itertools

# 从count(1)开头取「小于等于8的偶数」
# 先取自然数,再用takewhile加偶数和大小条件
natuals = itertools.count(1)
even_le8 = itertools.takewhile(lambda x: x <= 8 and x % 2 == 0, natuals)
print(list(even_le8))  # 输出: [2, 4, 6, 8]

2. chain():多个可迭代对象无缝拼接

chain(*iterables)可以把多个可迭代对象(不管类型是否完全一致)依次拼接成一个大的迭代器,不需要创建中间的「拼接列表」,节省了内存。

import itertools

# 拼接字符串、列表、range对象
all_chars = itertools.chain("AB", ["C", "D"], range(5, 8))
print(list(all_chars))  # 输出: ['A', 'B', 'C', 'D', 5, 6, 7]

3. groupby():相邻重复元素分组器

groupby(iterable, key=None)会把相邻且key值相同的元素分成一组,返回一个包含(key, 组迭代器)的迭代器——这里要注意「相邻」两个字:如果不相邻的元素key值相同,会分成不同的组,所以使用前通常要先按key排序!

import itertools

# 1. 基础用法:按相邻原元素分组
raw_str = "AAABBBCCAAA"
for key, group in itertools.groupby(raw_str):
    print(key, list(group))
"""
输出:
A ['A', 'A', 'A']
B ['B', 'B', 'B']
C ['C', 'C']
A ['A', 'A', 'A']
"""

# 2. 进阶用法:自定义key函数,且先排序保证分组正确
unsorted_str = "AaaBBbcCAAa"
# 先按大写字母排序
sorted_str = sorted(unsorted_str, key=lambda c: c.upper())
for key, group in itertools.groupby(sorted_str, key=lambda c: c.upper()):
    print(key, list(group))
"""
输出:
A ['A', 'a', 'a', 'A', 'A', 'a']
B ['B', 'B', 'b']
C ['c', 'C']
"""

组合迭代器

组合迭代器是itertools模块中「最常用也最实用」的部分,专门用来生成排列、组合、笛卡尔积这类数学组合场景的结果,不需要自己写复杂的递归或嵌套循环。

1. product():笛卡尔积生成器

product(*iterables, repeat=1)会计算多个可迭代对象的笛卡尔积,如果只有一个可迭代对象但需要「多次自身相乘」,可以用repeat参数简化写法。

import itertools

# 1. 两个列表的笛卡尔积
color_size = list(itertools.product(["红", "蓝"], ["S", "M", "L"]))
print(color_size)
# 输出: [('红', 'S'), ('红', 'M'), ('红', 'L'), ('蓝', 'S'), ('蓝', 'M'), ('蓝', 'L')]

# 2. 单个可迭代对象自身相乘3次(相当于掷3次骰子的所有可能)
dice_3 = list(itertools.product(range(1, 7), repeat=3))
print(dice_3[:5])  # 只看前5个结果
# 输出: [(1, 1, 1), (1, 1, 2), (1, 1, 3), (1, 1, 4), (1, 1, 5)]

2. permutations():无重复全排列生成器

permutations(iterable, r=None)会生成可迭代对象中长度为r的所有无重复排列(顺序不同视为不同结果),如果不指定r,默认生成全长度排列。

import itertools

# 生成ABC中长度为2的排列
perm_abc_2 = list(itertools.permutations("ABC", 2))
print(perm_abc_2)
# 输出: [('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]

3. combinations():无重复组合生成器

combinations(iterable, r)会生成可迭代对象中长度为r的所有无重复组合(顺序不同视为同一结果,且元素不能重复使用),必须指定r参数。

import itertools

# 生成ABC中长度为2的组合
comb_abc_2 = list(itertools.combinations("ABC", 2))
print(comb_abc_2)
# 输出: [('A', 'B'), ('A', 'C'), ('B', 'C')]

实际应用:用莱布尼茨公式近似计算圆周率

莱布尼茨公式是一个经典的「用级数近似计算π」的方法,公式逻辑很简单,用itertools组合起来实现,既优雅又内存友好。

import itertools

def approximate_pi(N):
    """
    用莱布尼茨公式近似计算π:π/4 = 1 - 1/3 + 1/5 - 1/7 + ...
    :param N: 取前N项计算
    :return: π的近似值
    """
    # 1. 用count()生成无限奇数序列
    odd_seq = itertools.count(1, 2)
    
    # 2. 用islice()(另一个常用有限迭代器,用来切片迭代器)取前N项
    first_N_odds = itertools.islice(odd_seq, N)
    
    # 3. 用生成器表达式(结合索引奇偶变号)生成级数项
    series = (4 / x * (-1)**i for i, x in enumerate(first_N_odds))
    
    # 4. 求和得到近似值
    return sum(series)

# 测试不同N的精度
print(f"取前10项:{approximate_pi(10):.10f}")      # 约3.0418396189
print(f"取前100项:{approximate_pi(100):.10f}")    # 约3.1315929036
print(f"取前10000项:{approximate_pi(10000):.10f}") # 约3.1414926536

性能考虑

itertools工具之所以高效,主要有两个原因:

  1. 所有函数都返回迭代器:不会一次性计算所有值并加载到内存,处理大数据集(比如百万级以上的序列)时优势非常明显;
  2. 底层用C实现:比自己写的Python循环快很多。

使用时还可以结合「生成器表达式」进一步优化,避免创建临时列表。


总结

itertools模块是Python处理可迭代对象的「瑞士军刀」,核心覆盖三个场景:

  • 生成无限规律序列;
  • 裁剪、拼接、筛选、分组已有序列;
  • 生成排列、组合、笛卡尔积这类数学组合结果。

学会灵活组合这些工具,能让你的代码更简洁、更高效、更有Pythonic风格!