训练框架使用

框架导引 现有网络部分代码运行流程总结

这套框架设计得非常自动化,其核心是“配置驱动”。整个流程就像一条精密的自动化生产线,从读取您的 .yaml 配置文件开始。

第 1 步:加载蓝图 (读取配置)

  • 您在终端运行 python train.py --config-name=PatchEXNet,整个流程启动。

  • Hydra 框架首先介入,它会读取并解析 conf/PatchEXNet.yaml 文件。这个文件就是所有后续操作的“总设计蓝图”,规定了要用哪个数据集、哪个模型、哪个损失函数、学习率是多少等等。

第 2 步:准备原材料 (加载数据)

  • train.py 根据“蓝图”中的 train_datasetval_dataset 配置,初始化 DataInterface 模块。

  • DataInterface 接着根据配置,实例化 ZWTDataset_P 这个数据集类。

  • ZWTDataset_P 会读取配置中的 active_buffers 列表,得知这次任务需要加载哪些 .npy 文件(例如 WarpedFrame, HoleMask 等)。

  • 当训练开始时,DataLoader 会调用 ZWTDataset_P__getitem__ 方法,该方法从硬盘中读取单帧所需的所有 npy 文件,将它们打包成一个Python 字典并返回。

第 3 步:组装生产线 (初始化模型与训练器)

  • train.py 根据“蓝图”中的 model, loss, optimizer 等配置,初始化核心的 ModelInterface 模块。

  • ModelInterface__init__ 方法中:

    • 它调用 load_model_wrapper,根据 model.name ("PatchEXNet") 和 model.args,动态地创建 PatchEXNet 类的实例。

    • 它调用 load_loss,根据 loss.name ("PatchEXLoss"),创建 PatchEXLoss 类的实例。

    • 它还会根据配置,准备好优化器 (Adam) 和学习率调度器 (MultiStepLR)。

第 4 步:开机生产 (执行训练循环)

  • train.py 调用 trainer.fit(...),将控制权完全交给 PyTorch Lightning。

  • PyTorch Lightning 自动执行以下循环:

    1. DataLoader 中取出一个批次(batch)的数据(这是一个数据字典)。

    2. 调用 ModelInterfacetraining_step 方法,并将 batch 传给它。

    3. training_step 内部:

      • 首先调用模型自带的 decomp_batch 方法,将 batch 字典分解为 input 字典和 label 字典。

      • 然后调用 self.model(input),即 PatchEXNetforward 方法。在这里,模型内部负责将 input 字典中的张量拼接起来,并完成一次完整的前向传播,得到预测结果 pred

      • 接着调用 loss_input 方法,将 predlabelinput 整理成一个字典,送入损失函数。

      • 最后,调用 self.loss(...) 计算损失值,并通过 self.log(...) 将其发送到 TensorBoard。

    4. PyTorch Lightning 自动处理反向传播和优化器更新。

  • 在每个验证周期,会调用 validation_step,执行您为 PatchEXNet 添加的专属逻辑,计算 PSNR 并记录图像到 TensorBoard。 未来新增一个网络的修改清单 (SOP)

假设您未来要实现一个新的模型,名为 FutureNet。您只需要按照以下清单进行修改即可:

✅ 第 1 步:实现模型本身 (FutureNet.py)

  • model/ 目录下创建一个新文件 FutureNet.py

  • 在其中定义 class FutureNet(nn.Module),并实现其网络结构 (__init__forward)。

  • 【关键】 必须为 FutureNet 类实现框架所需的两个辅助方法:decomp_batch(self, batch, device)loss_input(self, prediction, label, input)

✅ 第 2 步:实现对应的损失函数 (FutureNetLoss)

  • 打开 ModelLoss.py 文件。 +
  • 添加一个新的损失函数类 class FutureNetLoss(nn.Module)

  • 确保其 forward 方法能够正确处理由 FutureNet.loss_input 方法生成的字典。

✅ 第 3 步:创建新的配置文件 (FutureNet.yaml)

  • conf/ 目录下,复制 PatchEXNet.yaml 并重命名为 FutureNet.yaml

  • 逐项修改这个新文件:

    • train_setting.name: 改为一个新的、有意义的实验名称,如 FutureNet/InitialTest

    • train_dataset.active_buffers: 更新此列表,使其包含 FutureNet 训练所需的所有输入和标签的确切名称。

    • model.name: 修改为 FutureNet

    • model.args: 根据 FutureNet__init__ 方法,提供必要的参数。

    • loss.name: 修改为 FutureNetLoss

    • optimizer / scheduler: 根据新模型的需要,调整学习率等参数。

✅ 第 4 步:让框架“认识”新组件

  • 您的框架通过 load_model_wrapperload_loss 等函数动态加载模块。您需要找到这些函数(可能在 model/loader.py 或类似文件中),并将新的 FutureNetFutureNetLoss 添加到它们的“可识别列表”中。通常是在一个字典或 if/elif 结构中加入新的分支。

  • 同时,在 model_interface.py 文件顶部,添加 from model.FutureNet import FutureNet,以便 isinstance 检查能够正常工作。

✅ 第 5 步:添加验证和测试的专属逻辑

  • 打开 model_interface.py

  • validation_steptest_step 函数的 if/isinstance 链中,FutureNet 添加新的 elif isinstance(self.model, FutureNet): 分支

  • 在这个新的代码块中,编写用于计算性能指标(如 PSNR)、记录图像到 TensorBoard 以及保存测试结果的专属代码。

✅ 第 6 步:运行!

  • 训练: python train.py --config-name=FutureNet

  • 测试: (在训练完成后,更新 .yaml 文件中的 load_pathpython test.py --config-name=FutureNet