训练框架使用
框架导引 现有网络部分代码运行流程总结
这套框架设计得非常自动化,其核心是“配置驱动”。整个流程就像一条精密的自动化生产线,从读取您的 .yaml
配置文件开始。
第 1 步:加载蓝图 (读取配置)
-
您在终端运行
python train.py --config-name=PatchEXNet
,整个流程启动。 -
Hydra 框架首先介入,它会读取并解析
conf/PatchEXNet.yaml
文件。这个文件就是所有后续操作的“总设计蓝图”,规定了要用哪个数据集、哪个模型、哪个损失函数、学习率是多少等等。
第 2 步:准备原材料 (加载数据)
-
train.py
根据“蓝图”中的train_dataset
和val_dataset
配置,初始化DataInterface
模块。 -
DataInterface
接着根据配置,实例化ZWTDataset_P
这个数据集类。 -
ZWTDataset_P
会读取配置中的active_buffers
列表,得知这次任务需要加载哪些.npy
文件(例如WarpedFrame
,HoleMask
等)。 -
当训练开始时,
DataLoader
会调用ZWTDataset_P
的__getitem__
方法,该方法从硬盘中读取单帧所需的所有npy
文件,将它们打包成一个Python 字典并返回。
第 3 步:组装生产线 (初始化模型与训练器)
-
train.py
根据“蓝图”中的model
,loss
,optimizer
等配置,初始化核心的ModelInterface
模块。 -
在
ModelInterface
的__init__
方法中:-
它调用
load_model_wrapper
,根据model.name
("PatchEXNet") 和model.args
,动态地创建PatchEXNet
类的实例。 -
它调用
load_loss
,根据loss.name
("PatchEXLoss"),创建PatchEXLoss
类的实例。 -
它还会根据配置,准备好优化器 (Adam) 和学习率调度器 (MultiStepLR)。
-
第 4 步:开机生产 (执行训练循环)
-
train.py
调用trainer.fit(...)
,将控制权完全交给 PyTorch Lightning。 -
PyTorch Lightning 自动执行以下循环:
-
从
DataLoader
中取出一个批次(batch)的数据(这是一个数据字典)。 -
调用
ModelInterface
的training_step
方法,并将batch
传给它。 -
在
training_step
内部:-
首先调用模型自带的
decomp_batch
方法,将batch
字典分解为input
字典和label
字典。 -
然后调用
self.model(input)
,即PatchEXNet
的forward
方法。在这里,模型内部负责将input
字典中的张量拼接起来,并完成一次完整的前向传播,得到预测结果pred
。 -
接着调用
loss_input
方法,将pred
、label
和input
整理成一个字典,送入损失函数。 -
最后,调用
self.loss(...)
计算损失值,并通过self.log(...)
将其发送到 TensorBoard。
-
-
PyTorch Lightning 自动处理反向传播和优化器更新。
-
-
在每个验证周期,会调用
validation_step
,执行您为PatchEXNet
添加的专属逻辑,计算 PSNR 并记录图像到 TensorBoard。 未来新增一个网络的修改清单 (SOP)
假设您未来要实现一个新的模型,名为 FutureNet
。您只需要按照以下清单进行修改即可:
✅ 第 1 步:实现模型本身 (FutureNet.py
)
-
在
model/
目录下创建一个新文件FutureNet.py
。 -
在其中定义
class FutureNet(nn.Module)
,并实现其网络结构 (__init__
和forward
)。 -
【关键】 必须为
FutureNet
类实现框架所需的两个辅助方法:decomp_batch(self, batch, device)
和loss_input(self, prediction, label, input)
。
✅ 第 2 步:实现对应的损失函数 (FutureNetLoss
)
- 打开
ModelLoss.py
文件。 + -
添加一个新的损失函数类
class FutureNetLoss(nn.Module)
。 -
确保其
forward
方法能够正确处理由FutureNet.loss_input
方法生成的字典。
✅ 第 3 步:创建新的配置文件 (FutureNet.yaml
)
-
在
conf/
目录下,复制PatchEXNet.yaml
并重命名为FutureNet.yaml
。 -
逐项修改这个新文件:
-
train_setting.name
: 改为一个新的、有意义的实验名称,如FutureNet/InitialTest
。 -
train_dataset.active_buffers
: 更新此列表,使其包含FutureNet
训练所需的所有输入和标签的确切名称。 -
model.name
: 修改为FutureNet
。 -
model.args
: 根据FutureNet
的__init__
方法,提供必要的参数。 -
loss.name
: 修改为FutureNetLoss
。 -
optimizer
/scheduler
: 根据新模型的需要,调整学习率等参数。
-
✅ 第 4 步:让框架“认识”新组件
-
您的框架通过
load_model_wrapper
和load_loss
等函数动态加载模块。您需要找到这些函数(可能在model/loader.py
或类似文件中),并将新的FutureNet
和FutureNetLoss
添加到它们的“可识别列表”中。通常是在一个字典或if/elif
结构中加入新的分支。 -
同时,在
model_interface.py
文件顶部,添加from model.FutureNet import FutureNet
,以便isinstance
检查能够正常工作。
✅ 第 5 步:添加验证和测试的专属逻辑
-
打开
model_interface.py
。 -
在
validation_step
和test_step
函数的if/isinstance
链中,为FutureNet
添加新的elif isinstance(self.model, FutureNet):
分支。 -
在这个新的代码块中,编写用于计算性能指标(如 PSNR)、记录图像到 TensorBoard 以及保存测试结果的专属代码。
✅ 第 6 步:运行!
-
训练:
python train.py --config-name=FutureNet
-
测试: (在训练完成后,更新
.yaml
文件中的load_path
)python test.py --config-name=FutureNet