实战项目:工业缺陷检测

引言

你有没有想过,手机屏幕、汽车零件、药片这些产品在出厂前,是怎么被逐一检查有没有瑕疵的?靠人眼?一天看几万个零件,难免会眼花。更何况,有些缺陷比头发丝还细。

工业缺陷检测就是用机器代替人眼,24小时不间断地去“找茬”。它背后靠的不是魔法,而是计算机视觉和深度学习。这篇文章就是给开发者准备的“找茬指南”,重点解决一个特别常见的难题:正常样本多到用不完,缺陷样本却凤毛麟角。这种场景在行业里叫异常检测

我们会从传统方法聊起,一直讲到卷积自编码器这种深度学习方案,还给出可以直接跑的 PyTorch 代码和部署思路。无论你是刚入门,还是正打算把模型跑在产线上,希望这篇文章都能帮到你。

📂 所属阶段:第二阶段 — 深度学习视觉基础(CNN 篇)
🔗 相关章节:实战项目一:智能人脸考勤系统 · 实战项目三:自动驾驶感知


1. 工业缺陷检测是什么

简单说,就是用摄像头拍下产品图像,然后用算法自动判断它是否合格。相比人工检测,机器不会累、标准统一,而且能留下完整的数据记录。

1.1 为什么工厂需要它?

  • 质量更稳:人眼会有疲劳、情绪、经验差异,机器不会。算法一旦定好,对所有产品一视同仁。
  • 省钱省时间:前期投入一笔开发成本,后期可以省下大量质检人力;更重要的是,能在早期截住次品,避免后续返工或者召回造成的更大损失。
  • 守护安全与品牌:一颗有缺陷的螺丝可能毁掉一台设备,一批外观不良的产品可能砸了多年经营的招牌。

1.2 常见的缺陷长什么样?

在生产线上,缺陷五花八门,大致可以归为这几类:

  • 表面缺陷:划痕、凹坑、污渍、裂纹、颜色不均。比如手机玻璃上的细微划痕。
  • 结构/尺寸缺陷:尺寸超差、变形、缺料、内部气泡、材料分层。
  • 装配缺陷:零件装错了位置、螺丝漏拧、焊点虚焊。

检测时最头疼的几个问题

  1. 光线和角度总在变:车间环境并不像实验室那样稳定,一天内亮度变化、产品摆放角度稍有不同,图像就可能差别很大。
  2. 缺陷太小、太像背景:比如带有木纹的地板上有一个小裂纹,肉眼都费劲,机器更难分辨。
  3. “次品”太少了:一条稳定产线上,可能 99.9% 都是正常品,缺陷样本一个月也攒不了几个。这让传统“看大量缺陷样本学习”的分类模型很难训练。
  4. 速度与精度要兼得:高速流水线上一秒钟过好几个产品,一张图的处理时间只有几毫秒,还得保持低误报率。

这些难点决定了我们不能用普通的“猫狗分类”思路去解决问题,而要采用异常检测的方法。


2. 核心技术:异常检测怎么玩?

核心逻辑其实不复杂:让模型只学习“正常产品长什么样”,然后凡是看起来不太对的,都归为异常。 就像我们只见过完整的苹果,突然出现一个带虫眼的,立刻就能意识到它不正常。

根据手头数据量的多少和产品图像的复杂程度,通常有两种做法。

2.1 传统方法:当数据不多、纹理简单时

如果你的产品纹理很规则(比如纯色金属片、简单图案的布料),而且正常样本只有几百或几千张,用传统的机器学习反而更省力,甚至不需要 GPU。

怎么做?

整体分三步:

  1. 提取特征:把每张图像转换成一串能描述“正常模样”的数字,比如纹理均匀度、颜色分布、边缘梯度等。
  2. 降维与标准化:把特征维度压一压,去掉冗余信息,并缩放到同一尺度。
  3. 训练异常检测模型:用“孤立森林”或“单类支持向量机”这类算法,在正常样本的特征空间中圈出一块“正常区域”。新来的样本如果掉到区域外,就是异常。

下面是一个可直接使用的 Python 实现,用到 scikit-imagescikit-learn

import numpy as np
import cv2
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from skimage.feature import local_binary_pattern

class TraditionalAnomalyDetector:
    def __init__(self, method='isolation_forest', contamination=0.05, pca_dim=50):
        self.method = method
        self.contamination = contamination
        self.scaler = StandardScaler()
        self.pca = PCA(n_components=pca_dim)
        # 默认使用孤立森林,你也可以换成 OneClassSVM
        self.model = IsolationForest(contamination=contamination, random_state=42) \
            if method == 'isolation_forest' else None
    
    def extract_features(self, images):
        """从每张图中提取 LBP 纹理、梯度统计和灰度统计特征"""
        features = []
        for img in images:
            # 统一转成单通道灰度图
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if len(img.shape)==3 else img
            
            # 1. LBP 纹理特征:刻画局部纹理模式
            lbp = local_binary_pattern(gray, 24, 3, method='uniform')
            lbp_hist, _ = np.histogram(lbp.ravel(), bins=26, density=True)
            
            # 2. 梯度统计特征:捕捉边缘和轮廓变化
            grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            mag = np.sqrt(grad_x**2 + grad_y**2)
            grad_stats = [np.mean(mag), np.std(mag), np.percentile(mag, 25), np.percentile(mag, 75)]
            
            # 3. 灰度统计特征:亮暗分布
            gray_stats = [np.mean(gray), np.std(gray), np.median(gray)]
            
            features.append(np.concatenate([lbp_hist, grad_stats, gray_stats]))
        return np.array(features)
    
    def fit(self, normal_images):
        """只用正常样本来训练"""
        feats = self.extract_features(normal_images)
        feats = self.scaler.fit_transform(feats)    # 标准化
        feats = self.pca.fit_transform(feats)       # 降维
        self.model.fit(feats)                       # 学习正常边界
    
    def predict(self, images):
        """返回 (预测标签, 异常分数),标签 -1 表示异常,1 表示正常"""
        feats = self.extract_features(images)
        feats = self.scaler.transform(feats)
        feats = self.pca.transform(feats)
        # 孤立森林中,-1 是异常,分数越低越可能异常
        return self.model.predict(feats), -self.model.score_samples(feats)

💡 什么时候用它?
当你只有 CPU、数据量在几千张以内、产品纹理不复杂时,这套传统方案的性价比极高。你甚至可以不写深度学习框架的依赖,直接用打包工具部署。

2.2 深度学习方法:当数据多、纹理复杂时

如果产品表面本身就有复杂花纹(比如布面、印刷包装),传统人工设计的特征就很难覆盖所有“正常”变化。这时就需要卷积自编码器(Convolutional Autoencoder, CAE)出场了。

它为什么好用?

自编码器就像一个“记忆大师”,由两部分组成:

  • 编码器:把输入图像一步一步压缩成一个浓缩的特征向量(好比只记关键信息)。
  • 解码器:再从这个浓缩特征把图像“还原”出来。

如果只用正常品去训练它,那么解码器就只学会了“如何重建正常产品的样子”。当一个有缺陷的样本给进来时,解码器依旧会努力把它恢复成正常样貌,结果重建出来的图像和原图差别很大。我们只需要计算这个差别(重建误差),就可以判断是否存在缺陷。

用 PyTorch 从头实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as T

class ConvAutoencoder(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        # 编码器:连续下采样,提取高层特征
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1), nn.ReLU(inplace=True),     # 224→112
            nn.Conv2d(32, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),  # 112→56
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),# 56→28
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),# 28→14
            nn.Conv2d(256, 512, 4, 2, 1), nn.ReLU(inplace=True)             # 14→7
        )
        # 解码器:连续上采样,恢复图像
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),  # 7→14
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),  # 14→28
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),    # 28→56
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(inplace=True),                         # 56→112
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1), nn.Sigmoid()                          # 112→224,输出像素值[0,1]
        )
    
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

class DeepAnomalyDetector:
    def __init__(self, img_size=(3,224,224), lr=1e-4, device=None):
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = ConvAutoencoder(img_size[0]).to(self.device)
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.transform = T.Compose([
            T.ToPILImage(), T.Resize(img_size[1:]), T.ToTensor()
        ])
        self.threshold = None
    
    def fit(self, normal_imgs, epochs=50, batch_size=32):
        """用正常样本训练自编码器"""
        processed = [self.transform(img) for img in normal_imgs]
        loader = DataLoader(TensorDataset(torch.stack(processed)), batch_size=batch_size, shuffle=True)
        
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for (data,) in loader:
                data = data.to(self.device)
                recon = self.model(data)
                loss = self.criterion(recon, data)         # 让重建结果与原始输入尽量一致
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item() * len(data)
            if (epoch+1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Avg Loss: {total_loss/len(normal_imgs):.6f}")
    
    def set_threshold(self, normal_imgs, percentile=95):
        """基于正常样本的重建误差,设定异常判定的阈值"""
        errors = []
        self.model.eval()
        with torch.no_grad():
            for img in normal_imgs:
                data = self.transform(img).unsqueeze(0).to(self.device)
                recon = self.model(data)
                errors.append(self.criterion(recon, data).item())
        self.threshold = np.percentile(errors, percentile)
        print(f"Set threshold to {self.threshold:.6f}")
    
    def predict(self, imgs):
        """返回每张图片的检测结果:是否缺陷、重建误差、置信度"""
        if self.threshold is None:
            raise ValueError("请先调用 set_threshold 方法设定阈值!")
        results = []
        self.model.eval()
        with torch.no_grad():
            for img in imgs:
                data = self.transform(img).unsqueeze(0).to(self.device)
                recon = self.model(data)
                err = self.criterion(recon, data).item()
                results.append({
                    "is_defective": err > self.threshold,
                    "recon_error": err,
                    "confidence": min(err / self.threshold, 2.0)   # 越高代表越“确信”是缺陷
                })
        return results

📌 训练小贴士:训练时只用正常图像,但最好留一小部分正常样本来设定阈值(如 95 分位数),这样误判率可控。


3. 从实验到产线:部署与优化

代码能跑只是第一步,真正上到流水线,还要考虑速度、稳定性和可维护性。

3.1 加速部署的实用技巧

  1. 量化模型:PyTorch 的 torch.quantization 可以把模型从 32 位浮点数压缩到 8 位整数,体积缩小大约 4 倍,推理速度提升 2~3 倍,非常适合边缘设备。
  2. 缩小输入尺寸:如果 224×224 的细节对你分辨缺陷来说太“豪华”了,可以试试 128×128 甚至 96×96,速度会有肉眼可见的提升。
  3. 脱离 PyTorch 运行时:用 torch.onnx.export 导出 ONNX 模型,再用 OpenCV DNN 或 ONNX Runtime 加载推理。这样一来,部署包只需要 OpenCV,干净利落。

3.2 一个简易的工业部署骨架

下面这个类演示了如何把训练好的检测器用在实际场景里:单张检测、数据库记录、实时视频流处理一应俱全。

import cv2
import sqlite3
from datetime import datetime

class IndustrialDefectSystem:
    def __init__(self, detector, db_path="defect_results.db"):
        self.detector = detector
        self.db_path = db_path
        self._init_db()
    
    def _init_db(self):
        """创建存放检测结果的数据库"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_defective BOOLEAN,
                    recon_error REAL,
                    confidence REAL
                )
            ''')
    
    def detect_single(self, img):
        """检测单张图片并持久化结果"""
        res = self.detector.predict([img])[0]
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT INTO results (is_defective, recon_error, confidence)
                VALUES (?, ?, ?)
            ''', (res["is_defective"], res["recon_error"], res["confidence"]))
        return res
    
    def detect_video(self, source=0, show=True):
        """实时视频流检测(source=0 表示默认摄像头)"""
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            raise ValueError("无法打开视频源")
        
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            res = self.detect_single(frame)
            # 可视化:缺陷框红色,正常框绿色
            color = (0,0,255) if res["is_defective"] else (0,255,0)
            text = f"DEFECT: {res['confidence']:.2f}" if res["is_defective"] else "OK"
            cv2.putText(frame, text, (30,30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
            if show:
                cv2.imshow("Defect Detection", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        cap.release()
        cv2.destroyAllWindows()

4. 学习路线与总结

学习建议

  • 先走通传统方案:不急着上深度学习,用孤立森林那套代码先看看你的产品图像能否被简单特征区分开。它可以帮助你快速理解异常检测的本质。
  • CAE 是工业常用船票:当传统方法撑不住时,卷积自编码器通常是深度学习方案里最稳定、最易落地的起点。
  • 数据质量比模型花哨更重要:尽可能收集不同光照、不同批次、不同角度的正常样本;少量缺陷样本只用来帮你验证阈值,不需要参与训练。
  • 把鲁棒性当成第一指标:部署前一定要用强光、弱光、部分遮挡等极端图片反复测试。生产环境可不会跟你客气。

总结

工业缺陷检测的核心,其实就一句话:用最少的缺陷样本,解决最实际的生产问题。文中介绍的传统方法和卷积自编码器已经能覆盖大部分常见场景,而且实现和部署成本相对可控。搞定这些基础,你就有了快速出活的底气,之后再去挑战 VAE、GAN、PatchCore 等更高级的方法也不迟。

工业部署中,**稳定性远远比“实验室准确率”重要**。不要受到各种 SOTA 模型的诱惑,优先选择推理速度快、结构简单、容易排查问题的方案。