训练框架使用
框架导引 现有网络部分代码运行流程总结
这套框架设计得非常自动化,其核心是“配置驱动”。整个流程就像一条精密的自动化生产线,从读取您的 .yaml 配置文件开始。
第 1 步:加载蓝图 (读取配置)
-
您在终端运行
python train.py --config-name=PatchEXNet,整个流程启动。 -
Hydra 框架首先介入,它会读取并解析
conf/PatchEXNet.yaml文件。这个文件就是所有后续操作的“总设计蓝图”,规定了要用哪个数据集、哪个模型、哪个损失函数、学习率是多少等等。
第 2 步:准备原材料 (加载数据)
-
train.py根据“蓝图”中的train_dataset和val_dataset配置,初始化DataInterface模块。 -
DataInterface接着根据配置,实例化ZWTDataset_P这个数据集类。 -
ZWTDataset_P会读取配置中的active_buffers列表,得知这次任务需要加载哪些.npy文件(例如WarpedFrame,HoleMask等)。 -
当训练开始时,
DataLoader会调用ZWTDataset_P的__getitem__方法,该方法从硬盘中读取单帧所需的所有npy文件,将它们打包成一个Python 字典并返回。
第 3 步:组装生产线 (初始化模型与训练器)
-
train.py根据“蓝图”中的model,loss,optimizer等配置,初始化核心的ModelInterface模块。 -
在
ModelInterface的__init__方法中:-
它调用
load_model_wrapper,根据model.name("PatchEXNet") 和model.args,动态地创建PatchEXNet类的实例。 -
它调用
load_loss,根据loss.name("PatchEXLoss"),创建PatchEXLoss类的实例。 -
它还会根据配置,准备好优化器 (Adam) 和学习率调度器 (MultiStepLR)。
-
第 4 步:开机生产 (执行训练循环)
-
train.py调用trainer.fit(...),将控制权完全交给 PyTorch Lightning。 -
PyTorch Lightning 自动执行以下循环:
-
从
DataLoader中取出一个批次(batch)的数据(这是一个数据字典)。 -
调用
ModelInterface的training_step方法,并将batch传给它。 -
在
training_step内部:-
首先调用模型自带的
decomp_batch方法,将batch字典分解为input字典和label字典。 -
然后调用
self.model(input),即PatchEXNet的forward方法。在这里,模型内部负责将input字典中的张量拼接起来,并完成一次完整的前向传播,得到预测结果pred。 -
接着调用
loss_input方法,将pred、label和input整理成一个字典,送入损失函数。 -
最后,调用
self.loss(...)计算损失值,并通过self.log(...)将其发送到 TensorBoard。
-
-
PyTorch Lightning 自动处理反向传播和优化器更新。
-
-
在每个验证周期,会调用
validation_step,执行您为PatchEXNet添加的专属逻辑,计算 PSNR 并记录图像到 TensorBoard。 未来新增一个网络的修改清单 (SOP)
假设您未来要实现一个新的模型,名为 FutureNet。您只需要按照以下清单进行修改即可:
✅ 第 1 步:实现模型本身 (FutureNet.py)
-
在
model/目录下创建一个新文件FutureNet.py。 -
在其中定义
class FutureNet(nn.Module),并实现其网络结构 (__init__和forward)。 -
【关键】 必须为
FutureNet类实现框架所需的两个辅助方法:decomp_batch(self, batch, device)和loss_input(self, prediction, label, input)。
✅ 第 2 步:实现对应的损失函数 (FutureNetLoss)
- 打开
ModelLoss.py文件。 + -
添加一个新的损失函数类
class FutureNetLoss(nn.Module)。 -
确保其
forward方法能够正确处理由FutureNet.loss_input方法生成的字典。
✅ 第 3 步:创建新的配置文件 (FutureNet.yaml)
-
在
conf/目录下,复制PatchEXNet.yaml并重命名为FutureNet.yaml。 -
逐项修改这个新文件:
-
train_setting.name: 改为一个新的、有意义的实验名称,如FutureNet/InitialTest。 -
train_dataset.active_buffers: 更新此列表,使其包含FutureNet训练所需的所有输入和标签的确切名称。 -
model.name: 修改为FutureNet。 -
model.args: 根据FutureNet的__init__方法,提供必要的参数。 -
loss.name: 修改为FutureNetLoss。 -
optimizer/scheduler: 根据新模型的需要,调整学习率等参数。
-
✅ 第 4 步:让框架“认识”新组件
-
您的框架通过
load_model_wrapper和load_loss等函数动态加载模块。您需要找到这些函数(可能在model/loader.py或类似文件中),并将新的FutureNet和FutureNetLoss添加到它们的“可识别列表”中。通常是在一个字典或if/elif结构中加入新的分支。 -
同时,在
model_interface.py文件顶部,添加from model.FutureNet import FutureNet,以便isinstance检查能够正常工作。
✅ 第 5 步:添加验证和测试的专属逻辑
-
打开
model_interface.py。 -
在
validation_step和test_step函数的if/isinstance链中,为FutureNet添加新的elif isinstance(self.model, FutureNet):分支。 -
在这个新的代码块中,编写用于计算性能指标(如 PSNR)、记录图像到 TensorBoard 以及保存测试结果的专属代码。
✅ 第 6 步:运行!
-
训练:
python train.py --config-name=FutureNet -
测试: (在训练完成后,更新
.yaml文件中的load_path)python test.py --config-name=FutureNet