3141 字
16 分钟
CNN训练图片验证码记录
项目结构
- model.py: 定义了用于验证码识别的 CNN 模型。
- setting.py: 包含与数据集、模型参数相关的配置信息。
- train.py: 实现了模型的训练逻辑,包括数据预处理、训练循环、验证以及早停机制等。
核心代码
模型定义 (model.py)
import torch.nn as nn
from setting import IMAGE_WIDTH, IMAGE_HEIGHT, MAX_CAPTCHA, ALL_CHAR_SET_LEN

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 定义卷积层和全连接层
        # ...

    def forward(self, x):
        # 前向传播过程
        # ...

设置文件 (setting.py)
IMAGE_HEIGHT = 40
IMAGE_WIDTH = 140
ALL_CHAR_SET = [...]  # 示意写法: 依次为 '0'-'9'、'A'-'Z'、'a'-'z',共 62 个字符(完整列表见下文"完整代码")
ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
MAX_CAPTCHA = 5
# 其他训练参数如学习率、批次大小等

训练脚本 (train.py)
1. 数据集类
class CaptchaDataset(Dataset):
    # 初始化函数、__len__ 和 __getitem__ 方法
    # 主要功能: 加载图片、转换为灰度图、从文件名提取标签、尺寸调整与归一化等
    # ...

2. 训练模型
def train_model():
    # 检查数据集完整性、选择设备(GPU/CPU)
    # 创建数据加载器、模型实例化、损失函数及优化器设置
    # 训练循环: 前向传播、计算损失、反向传播、更新权重
    # 验证模型性能、早停机制、保存最佳模型
    # ...

3. 验证模型
def validate_model(model, test_loader, device, criterion):
    # 在验证集上评估模型性能
    # 返回准确率和平均损失值
    # ...

4. 绘制训练曲线
def plot_training_curve(train_losses, train_accuracies, val_losses, val_accuracies, learning_rates):
    # 绘制训练和验证的损失与准确率变化曲线
    # 同时绘制学习率的变化情况
    # ...

总结
该训练过程首先定义一个适合验证码识别任务的 CNN 模型;接着通过 CaptchaDataset 类对数据进行预处理,并在训练过程中使用早停机制来避免过拟合;最后通过可视化工具展示模型在训练过程中的表现,包括损失、准确率以及学习率的变化趋势。这个结构化的流程可以作为解决类似问题的一个模板,方便后续的调整和扩展。
完整代码
model.py
#! /usr/bin/env python# -*- coding: utf-8 -*-# @Time : 2025/11/19 20:04# @Author : afish# @File : model.pyimport torch.nn as nn
import setting
class CNN(nn.Module):
    """Convolutional captcha classifier.

    Three identical-shaped convolution stages feed two fully connected
    layers. ``forward`` returns, for every input image, a flat row of
    ``setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN`` raw scores.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> dropout -> ReLU -> 2x max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that saved
        # state dicts keep the same parameter keys.
        self.layer1 = self._conv_stage(1, 32)  # first conv takes 1 input channel
        self.layer2 = self._conv_stage(32, 64)
        self.layer3 = self._conv_stage(64, 64)

        # Each of the three stages halves width and height, hence the // 8.
        flat_features = (setting.IMAGE_WIDTH // 8) * (setting.IMAGE_HEIGHT // 8) * 64
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.Dropout(0.5),
            nn.ReLU(),
        )
        self.rfc = nn.Sequential(
            nn.Linear(1024, setting.MAX_CAPTCHA * setting.ALL_CHAR_SET_LEN),
        )

    def forward(self, x):
        """Run the three conv stages, flatten per sample, then apply both FC heads."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        flattened = features.view(features.size(0), -1)
        hidden = self.fc(flattened)
        return self.rfc(hidden)


# Next listing: setting.py
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2025/11/19 20:14
# @Author  : afish
# @File    : setting.py
import string

# Image size, in pixels.
IMAGE_HEIGHT = 40
IMAGE_WIDTH = 140

# Character set: digits, then uppercase letters, then lowercase letters
# (62 symbols in total). The position of a character in this list is its
# class index, so the order must not change once a model has been trained.
ALL_CHAR_SET = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)

ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)

# Number of characters in one captcha.
MAX_CAPTCHA = 5

# Training parameters.
BATCH_SIZE = 256 * 2
EPOCHS = 50
LEARNING_RATE = 0.001

# File paths.
TRAIN_DATASET_PATH = "data/train"
TEST_DATASET_PATH = "data/test"
MODEL_SAVE_PATH = "model/captcha_model.pth"

# Next listing: train.py
#! /usr/bin/env python# -*- coding: utf-8 -*-# @Time : 2025/11/19 20:04# @Author : afish# @File : train.py
import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import DataLoader, Datasetfrom PIL import Image, ImageFileimport numpy as npfrom model import CNNimport settingfrom torchvision import transformsimport matplotlib.pyplot as pltimport randomimport timeimport math
# 允许加载截断的图片文件ImageFile.LOAD_TRUNCATED_IMAGES = True
class EarlyStopping:
    """Stop training once the validation loss has stopped improving."""

    def __init__(self, patience=7, verbose=True, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): number of consecutive non-improving calls after
                which ``early_stop`` is set.
            verbose (bool): whether to print progress messages.
            delta (float): minimum change that counts as an improvement.
            path (str): file the model weights are saved to.
        """
        self.path = path
        self.delta = delta
        self.verbose = verbose
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        """Record one validation result; checkpoint on improvement, count otherwise."""
        score = -val_loss  # higher score == lower loss
        has_baseline = self.best_score is not None

        if has_baseline and score < self.best_score + self.delta:
            # No sufficient improvement over the best score seen so far.
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = score
        self.save_checkpoint(val_loss, model)
        if has_baseline:
            # The very first call only establishes the baseline; later
            # improvements also reset the patience counter.
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write the model's state dict to ``self.path`` and remember the loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
class CaptchaDataset(Dataset):
    """Captcha images from one directory, labelled through their file names.

    File names are expected to look like ``1005~2A2G2.jpg``: the label is the
    text between the last ``~`` and the first following ``.``.
    """

    def __init__(self, data_path, transform=None):
        """Scan ``data_path`` and keep only images that pass ``Image.verify``.

        Args:
            data_path: directory holding ``.jpg`` / ``.jpeg`` / ``.png`` files.
            transform: optional callable applied to each grayscale PIL image;
                when omitted a built-in resize + scale-to-[0, 1] is used.

        If corrupted files are found the user is asked whether to delete
        them. When no answer can be read (stdin closed / non-interactive
        run) the files are kept on disk and merely skipped.
        """
        self.data_path = data_path
        self.transform = transform
        self.image_files = [f for f in os.listdir(data_path) if f.endswith(('.jpg', '.jpeg', '.png'))]

        # Pre-check every file and remember which ones are unreadable.
        self.valid_files = []
        self.corrupted_files = []

        print("检查图片文件完整性...")
        for img_name in self.image_files:
            img_path = os.path.join(self.data_path, img_name)
            try:
                with Image.open(img_path) as img:
                    img.verify()
            except (IOError, SyntaxError, OSError) as e:
                print(f"损坏图片: {img_name} - 错误: {e}")
                self.corrupted_files.append(img_name)
            else:
                self.valid_files.append(img_name)

        print(f"有效图片: {len(self.valid_files)}, 损坏图片: {len(self.corrupted_files)}")

        if len(self.corrupted_files) > 0:
            try:
                response = input(f"发现 {len(self.corrupted_files)} 个损坏图片,是否删除? (y/n): ")
            except EOFError:
                # No interactive stdin: never delete files without consent.
                response = 'n'
            if response.lower() == 'y':
                for corrupted_file in self.corrupted_files:
                    os.remove(os.path.join(self.data_path, corrupted_file))
                print("已删除损坏图片")
                # ``valid_files`` already excludes every corrupted file, so the
                # directory is not listed again (a second listing could pick up
                # files that were never verified).
            else:
                print("将跳过损坏图片")

    def __len__(self):
        """Return the number of verified images."""
        return len(self.valid_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_vector)`` for the ``idx``-th valid file.

        Read or transform failures do not raise: an all-black / all-zero
        image of the configured size is substituted and a message printed.
        """
        img_name = self.valid_files[idx]
        img_path = os.path.join(self.data_path, img_name)

        try:
            # The context manager releases the file handle right away;
            # ``convert`` returns an independent grayscale image.
            with Image.open(img_path) as source:
                image = source.convert('L')
        except (IOError, SyntaxError, OSError):
            print(f"读取图片失败: {img_name}, 使用替代图片")
            image = Image.new('L', (setting.IMAGE_WIDTH, setting.IMAGE_HEIGHT), 0)

        # Label comes from the file name (format: 1005~2A2G2.jpg).
        label_str = img_name.split('~')[-1].split('.')[0]

        if len(label_str) != setting.MAX_CAPTCHA:
            print(f"警告: 标签长度不匹配: {label_str} (期望长度: {setting.MAX_CAPTCHA})")

        label = self.text2vec(label_str)

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"图片转换失败: {img_name}, 错误: {e}")
                image = torch.zeros(1, setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH)
        else:
            # Default path: resize -> numpy -> scale to [0, 1] -> add channel dim.
            try:
                image = image.resize((setting.IMAGE_WIDTH, setting.IMAGE_HEIGHT))
                image = np.array(image)
                image = torch.FloatTensor(image) / 255.0
                image = image.unsqueeze(0)
            except Exception as e:
                print(f"图片处理失败: {img_name}, 错误: {e}")
                image = torch.zeros(1, setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH)

        return image, label

    def text2vec(self, text):
        """Encode ``text`` as a flat one-hot vector of MAX_CAPTCHA * ALL_CHAR_SET_LEN.

        Characters beyond ``setting.MAX_CAPTCHA`` are ignored; positions a
        short label does not reach stay all-zero. A character outside the
        character set is reported and replaced by a random class so that
        training is not interrupted (that sample's label is then wrong for
        the affected position).
        """
        vector = torch.zeros(setting.MAX_CAPTCHA, setting.ALL_CHAR_SET_LEN)
        for i, char in enumerate(text[:setting.MAX_CAPTCHA]):
            try:
                idx = setting.ALL_CHAR_SET.index(char)
            except ValueError:
                print(f"错误: 字符 '{char}' 不在字符集中")
                idx = random.randint(0, setting.ALL_CHAR_SET_LEN - 1)
            vector[i][idx] = 1
        return vector.view(-1)
def vec2text(vec):
    """Decode a flat per-character score vector back into its text.

    The vector is viewed as ``setting.MAX_CAPTCHA`` rows and the arg-max of
    each row is looked up in ``setting.ALL_CHAR_SET``.
    """
    rows = vec.view(setting.MAX_CAPTCHA, -1)
    return ''.join(
        setting.ALL_CHAR_SET[torch.argmax(rows[position]).item()]
        for position in range(setting.MAX_CAPTCHA)
    )
def check_dataset():
    """Return True when the training directory exists and is not empty.

    Otherwise print a hint about how to create the dataset and return False.
    """
    train_dir = setting.TRAIN_DATASET_PATH
    if os.path.exists(train_dir) and len(os.listdir(train_dir)) != 0:
        return True

    print("训练集不存在或为空!")
    print("请先运行 split_dataset.py 来分割数据集")
    return False
def train_model():
    """Train the captcha CNN and save checkpoints, curves and the final weights.

    Steps: verify the training set, build the data loaders (the test set is
    optional), train for up to ``setting.EPOCHS`` epochs with Adam, gradient
    clipping and a StepLR schedule, validate after every epoch when a test
    loader is available (early stopping on validation loss, best-accuracy
    checkpoint), then save the final weights and plot the training curves.

    Returns None; returns early without training when ``check_dataset`` fails.
    """
    if not check_dataset():
        return

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    transform = transforms.Compose([
        transforms.Resize((setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    train_dataset = CaptchaDataset(setting.TRAIN_DATASET_PATH, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=setting.BATCH_SIZE, shuffle=True)

    print(f"训练集大小: {len(train_dataset)}")

    if os.path.exists(setting.TEST_DATASET_PATH) and len(os.listdir(setting.TEST_DATASET_PATH)) > 0:
        test_dataset = CaptchaDataset(setting.TEST_DATASET_PATH, transform=transform)
        test_loader = DataLoader(test_dataset, batch_size=setting.BATCH_SIZE, shuffle=False)
        print(f"测试集大小: {len(test_dataset)}")
    else:
        test_loader = None
        print("未找到测试集,将只使用训练集")

    model = CNN().to(device)
    print("模型结构:")
    print(model)

    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = optim.Adam(model.parameters(), lr=setting.LEARNING_RATE)

    # Multiply the learning rate by 0.1 every 20 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    early_stopping = EarlyStopping(
        patience=15,
        verbose=True,
        delta=0.001,
        path=setting.MODEL_SAVE_PATH.replace('.pth', '_best.pth'),
    )

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    learning_rates = []

    # Must exist before the first checkpoint is written.
    os.makedirs(os.path.dirname(setting.MODEL_SAVE_PATH), exist_ok=True)

    best_accuracy = 0.0
    start_time = time.time()

    print("开始训练...")
    for epoch in range(setting.EPOCHS):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        epoch_start_time = time.time()

        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            running_loss += loss.item()

            # A sample counts as correct only if every character matches.
            predicted = outputs.view(-1, setting.MAX_CAPTCHA, setting.ALL_CHAR_SET_LEN)
            labels_reshaped = labels.view(-1, setting.MAX_CAPTCHA, setting.ALL_CHAR_SET_LEN)

            _, predicted_chars = torch.max(predicted, 2)
            _, label_chars = torch.max(labels_reshaped, 2)

            correct += (predicted_chars == label_chars).all(dim=1).sum().item()
            total += labels.size(0)

            if batch_idx % 10 == 0:
                accuracy = 100 * correct / total if total > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                print(f'Epoch [{epoch + 1}/{setting.EPOCHS}], Batch [{batch_idx}/{len(train_loader)}], '
                      f'Loss: {loss.item():.4f}, Acc: {accuracy:.2f}%, LR: {current_lr:.2e}')

        # Read the learning rate BEFORE stepping the scheduler so the value
        # logged and plotted for this epoch is the one that was actually used.
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)
        scheduler.step()

        train_accuracy = 100 * correct / total if total > 0 else 0
        epoch_loss = running_loss / len(train_loader) if len(train_loader) > 0 else 0

        train_losses.append(epoch_loss)
        train_accuracies.append(train_accuracy)

        epoch_time = time.time() - epoch_start_time
        print(f'Epoch [{epoch + 1}/{setting.EPOCHS}], Time: {epoch_time:.2f}s, '
              f'Loss: {epoch_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, LR: {current_lr:.2e}')

        if test_loader:
            val_accuracy, val_loss = validate_model(model, test_loader, device, criterion)
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)

            print(f'Validation Accuracy: {val_accuracy:.2f}%, Validation Loss: {val_loss:.4f}')

            # Save the best-accuracy weights before the early-stopping check,
            # otherwise an improvement on the stopping epoch would be lost.
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save(model.state_dict(), setting.MODEL_SAVE_PATH)
                print(f'Best model saved with accuracy: {best_accuracy:.2f}%')

            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("早停: 停止训练")
                break
        else:
            # NOTE(review): indentation was lost in the published listing; this
            # branch is assumed to belong to ``if test_loader`` (no validation
            # data -> keep a resumable checkpoint per epoch). Confirm against
            # the original train.py.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, setting.MODEL_SAVE_PATH.replace('.pth', f'_epoch{epoch + 1}.pth'))

    total_time = time.time() - start_time
    print(f"训练完成! 总时间: {total_time:.2f}s")

    # Persist the final weights first: plotting opens a window via plt.show()
    # and may block or fail on a machine without a display.
    torch.save(model.state_dict(), setting.MODEL_SAVE_PATH.replace('.pth', '_final.pth'))
    print("最终模型已保存")

    plot_training_curve(train_losses, train_accuracies, val_losses, val_accuracies, learning_rates)
def validate_model(model, test_loader, device, criterion):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    The model is switched to eval mode and left in it. A sample counts as
    correct only when all of its character positions match the label.

    Returns:
        (accuracy, avg_loss): whole-captcha accuracy in percent and the mean
        of the per-batch losses; each is 0 when there is nothing to average.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()

            per_char_shape = (-1, setting.MAX_CAPTCHA, setting.ALL_CHAR_SET_LEN)
            predicted_chars = torch.max(outputs.view(*per_char_shape), 2).indices
            label_chars = torch.max(labels.view(*per_char_shape), 2).indices

            matches = (predicted_chars == label_chars).all(dim=1)
            n_correct += matches.sum().item()
            n_seen += labels.size(0)

    if n_seen > 0:
        accuracy = 100 * n_correct / n_seen
    else:
        accuracy = 0

    n_batches = len(test_loader)
    if n_batches > 0:
        avg_loss = loss_sum / n_batches
    else:
        avg_loss = 0

    return accuracy, avg_loss
def plot_training_curve(train_losses, train_accuracies, val_losses, val_accuracies, learning_rates):
    """Draw loss, accuracy and learning-rate curves side by side.

    The figure is written to ``training_curve.png`` and then shown. The
    validation curves are only drawn when their lists are non-empty.
    """
    plt.figure(figsize=(15, 5))

    # (subplot index, train series, train label, val series, val label, title, y label)
    metric_panels = (
        (1, train_losses, 'Train Loss', val_losses, 'Validation Loss',
         'Training and Validation Loss', 'Loss'),
        (2, train_accuracies, 'Train Accuracy', val_accuracies, 'Validation Accuracy',
         'Training and Validation Accuracy', 'Accuracy (%)'),
    )
    for position, train_series, train_label, val_series, val_label, title, y_label in metric_panels:
        plt.subplot(1, 3, position)
        plt.plot(train_series, label=train_label)
        if val_series:
            plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    # Learning rate gets a log-scaled axis of its own.
    plt.subplot(1, 3, 3)
    plt.plot(learning_rates)
    plt.title('Learning Rate Schedule')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.yscale('log')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_curve.png', dpi=300, bbox_inches='tight')
    plt.show()
# Script entry point: training starts only when train.py is run directly.
if __name__ == "__main__":
    train_model()

# Next listing: .pth-to-ONNX conversion script (export_to_onnx.py)
#! /usr/bin/env python# -*- coding: utf-8 -*-# @Time : 2025/11/19 23:09# @Author : afish# @File : export_to_onnx.py# export_to_onnx.pyimport os
import numpy as npimport onnximport onnxruntime as ortimport torch
import settingfrom model import CNN
def export_model_to_onnx():
    """Export the trained PyTorch weights to ONNX.

    Loads ``setting.MODEL_SAVE_PATH`` into a fresh ``CNN``, exports it with a
    dynamic batch axis, validates the file with the ONNX checker, optionally
    runs ``onnxoptimizer`` and finally smoke-tests inference.

    Returns:
        Path of the ONNX file to use (the optimized one when optimization
        succeeded, otherwise the plain export), or None when the weights are
        missing or loading / export / validation fails.
    """
    print("🚀 开始ONNX模型转换...")
    print("=" * 50)

    weights_file = setting.MODEL_SAVE_PATH
    if not os.path.exists(weights_file):
        print(f"❌ 模型文件不存在: {weights_file}")
        print("💡 请先确保训练完成并保存了最佳模型")
        return None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"📱 使用设备: {device}")

    model = CNN().to(device)

    try:
        state_dict = torch.load(weights_file, map_location=device)
        model.load_state_dict(state_dict)
        model.eval()  # dropout / batch-norm must be in inference mode for export
        print("✅ 模型加载成功!")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        return None

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📊 模型参数总数: {total_params:,}")

    # Dummy input with the configured image size, used to trace the model.
    batch_size = 1
    dummy_input = torch.randn(batch_size, 1, setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH).to(device)
    print(f"📐 输入尺寸: {dummy_input.shape}")

    onnx_path = "best_captcha_model.onnx"
    optimized_onnx_path = "captcha_model_optimized.onnx"

    export_options = dict(
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
        verbose=False,
    )
    try:
        print("🔄 正在导出ONNX模型...")
        torch.onnx.export(model, dummy_input, onnx_path, **export_options)
        print(f"✅ ONNX模型已导出: {onnx_path}")
    except Exception as e:
        print(f"❌ ONNX导出失败: {e}")
        return None

    try:
        print("🔍 验证ONNX模型...")
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        print("✅ ONNX模型验证通过!")
    except Exception as e:
        print(f"❌ ONNX模型验证失败: {e}")
        return None

    # Optional optimization: any failure (including a missing onnxoptimizer
    # package) falls back to the plain export.
    result_path = onnx_path
    try:
        import onnxoptimizer
        print("⚡ 正在优化ONNX模型...")
        passes = ['extract_constant_to_initializer', 'eliminate_unused_initializer']
        optimized_model = onnxoptimizer.optimize(onnx_model, passes)
        onnx.save(optimized_model, optimized_onnx_path)
        print(f"✅ 优化模型已保存: {optimized_onnx_path}")
        result_path = optimized_onnx_path
    except Exception as e:
        print(f"⚠️ 优化步骤跳过: {e}")

    test_onnx_inference(result_path)

    return result_path
def test_onnx_inference(onnx_path):
    """Smoke-test an ONNX file by running one random input through it.

    Prints the output shape and value range on success; any exception is
    caught and reported instead of being raised.
    """
    print("\n🧪 测试ONNX模型推理...")

    try:
        session = ort.InferenceSession(onnx_path)

        input_shape = (1, 1, setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH)
        test_input = np.random.randn(*input_shape).astype(np.float32)

        # 'input' is the tensor name given to the graph at export time.
        outputs = session.run(None, {'input': test_input})

        print("✅ ONNX推理测试成功!")
        first_output = outputs[0]
        print(f"📊 输出形状: {first_output.shape}")
        print(f"🎯 输出范围: [{first_output.min():.4f}, {first_output.max():.4f}]")

    except Exception as e:
        print(f"❌ ONNX推理测试失败: {e}")
def compare_performance(onnx_path, original_model_path):
    """Time 100 forward passes of the PyTorch model and of the ONNX model.

    Args:
        onnx_path: path of the exported ONNX file.
        original_model_path: path of the PyTorch state dict for ``CNN``.

    Both models receive the same random batch of 10 images. The timings and
    their ratio are printed; nothing is returned.
    """
    # Imported here because the module only imports ``time`` inside its
    # ``__main__`` guard, so the name is missing when this file is imported.
    import time

    print("\n⚡ 性能对比测试...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    original_model = CNN().to(device)
    original_model.load_state_dict(torch.load(original_model_path, map_location=device))
    original_model.eval()

    ort_session = ort.InferenceSession(onnx_path)
    # Ask the session for its input name instead of assuming 'input'.
    input_name = ort_session.get_inputs()[0].name

    test_data = torch.randn(10, 1, setting.IMAGE_HEIGHT, setting.IMAGE_WIDTH).to(device)
    test_data_np = test_data.cpu().numpy().astype(np.float32)

    # PyTorch timing. CUDA kernels are launched asynchronously, so wait for
    # the device before starting and before stopping the clock.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(100):
            original_model(test_data)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    torch_time = time.perf_counter() - start_time

    # ONNX Runtime timing.
    start_time = time.perf_counter()
    for _ in range(100):
        ort_session.run(None, {input_name: test_data_np})
    onnx_time = time.perf_counter() - start_time

    # Guard against a zero reading on coarse clocks.
    speedup = torch_time / onnx_time if onnx_time > 0 else float('inf')

    print(f"⏱️ PyTorch推理时间: {torch_time:.4f}s")
    print(f"⏱️ ONNX推理时间: {onnx_time:.4f}s")
    print(f"🚀 速度提升: {speedup:.2f}x")
class ONNXPredictor:
    """Thin wrapper around an ONNX Runtime session for captcha inference."""

    def __init__(self, onnx_path):
        """Open an inference session for ``onnx_path`` and cache its first input name."""
        self.onnx_path = onnx_path
        self.session = ort.InferenceSession(onnx_path)
        first_input = self.session.get_inputs()[0]
        self.input_name = first_input.name

    def predict(self, image_array):
        """Run the session on ``image_array`` and return its first output."""
        feed = {self.input_name: image_array}
        results = self.session.run(None, feed)
        return results[0]

    def predict_batch(self, image_arrays):
        """Call ``predict`` once per element and stack the results into one array."""
        return np.array([self.predict(single) for single in image_arrays])
# Script entry point: convert the trained model and print a short summary.
if __name__ == "__main__":
    import time

    # Work around the OpenMP duplicate-library warning (if it appears).
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    start_time = time.time()

    onnx_path = export_model_to_onnx()

    if not onnx_path:
        print("\n❌ ONNX转换失败")
    else:
        elapsed = time.time() - start_time
        print("\n" + "=" * 50)
        print("🎉 ONNX转换完成!")
        print(f"⏱️ 总耗时: {elapsed:.2f}秒")
        print(f"💾 模型文件: {onnx_path}")
        # The accuracy below is a fixed figure typed into the script; it is
        # not computed from the exported model.
        print(f"📊 准确率: 73.23% (最佳模型)")
        print("\n💡 使用示例:")
        print(f"python inference_onnx.py --image 您的图片.jpg --model {onnx_path}")

# End of listing (article footer: "CNN captcha-image training notes")
https://fuwari.vercel.app/posts/ai/yzm/cnn_1005/