# 面向对象的设计
** 主要内容:** 介绍了 d2l 库中面向对象的设计,包括 Utilities, Module, DataModule, Trainer 等设计
(i)
Module
contains models, losses, and optimization methods; (ii)DataModule
provides data loaders for training and validation; (iii) both classes are combined using theTrainer
class, which allows us to train models on a variety of hardware platforms. --d2l
Module 包含 模型,损失函数,优化方法
DataModule 提供加载数据的方法
Trainer 将二者结合,并允许在各种硬件平台上训练
# 工具 Utilities
add_to_class()
1
2
3
4
5def add_to_class(Class): #@save
"""Register functions as methods in created class."""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper装饰器,用于向创建的类中添加方法
对于先前已实例化的对象仍有效果,即也可以调用新添加的方法
例子:
1
2
3
4
5
6
7
8
9
10
11class A:
def __init__(self):
self.b = 1
a = A()
def do(self):
print('Class attribute "b" is', self.b)
a.do()class HyperParamters
1
2
3
4class HyperParameters: #@save
"""The base class of hyperparameters."""
def save_hyperparameters(self, ignore=[]):
raise NotImplemented用于将类__init__方法的参数保存为类的属性
save_hyperparamters()
例子:
1
2
3
4
5
6
7
8# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)输出:
1
2self.a = 1 self.b = 2
There is no self.c = True
class ProgressBoard
1
2
3
4
5
6
7
8
9
10class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation."""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImplemented用于动态绘制实验进度(进行中)
draw()
例子:
1
2
3
4board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
# 模块
class Module
1 | class Module(nn.Module, d2l.HyperParameters): #@save |
所有将要实现的模型的基础
__init__ 储存将要学习的参数 (learnable parameters)
training_step 接受一批数据,返回 loss
configure_optimizers 返回优化器,或者它们的一个列表
# 数据
class DataModule
1 | class DataModule(d2l.HyperParameters): #@save |
get_dataloader 返回一个加载数据集的生成器(generator)
# 训练
class Trainer
1 | class Trainer(d2l.HyperParameters): #@save |
fit 接受 model(class Module)和 data(class DataModule),并迭代 max_epochs 次训练
[end]
mofianger
2024/2/1
图片代码引自 en.d2l.ai
参考:3.2. Object-Oriented Design for Implementation — Dive into Deep Learning 1.0.3 documentation (d2l.ai)