# 面向对象的设计

**主要内容:** 介绍 d2l 库中面向对象的设计,包括 Utilities、Module、DataModule、Trainer 等组件

(i) Module contains models, losses, and optimization methods; (ii) DataModule provides data loaders for training and validation; (iii) both classes are combined using the Trainer class, which allows us to train models on a variety of hardware platforms. --d2l

Module 包含模型、损失函数和优化方法

DataModule 提供用于训练和验证的数据加载器(data loader)

Trainer 将二者结合,并允许在各种硬件平台上训练

# 工具 Utilities

  • add_to_class()

    def add_to_class(Class):  #@save
        """Register functions as methods in created class.

        Decorator factory: ``@add_to_class(Cls)`` attaches the decorated
        function to ``Cls`` under the function's own name via ``setattr``.
        The attribute is set on the class object itself, so instances that
        were created *before* the decoration can also call the new method.

        Args:
            Class: the class to extend.

        Returns:
            A decorator that registers its argument on ``Class`` and returns
            it unchanged, so the decorated name stays bound to the function.
        """
        def wrapper(obj):
            setattr(Class, obj.__name__, obj)
            # Return the function; otherwise the decorated name becomes None.
            return obj
        return wrapper

    装饰器,用于把函数注册为一个已创建的类的方法

    由于方法是直接添加到类本身上的,它对在此之前就已实例化的对象同样有效,即这些对象也能调用新添加的方法(见下例中的 a)

    例子:

    class A:
        """Toy class holding a single attribute ``b``."""

        def __init__(self):
            self.b = 1


    # Create the instance *before* the method exists, to show that
    # add_to_class also affects objects that were already created.
    a = A()


    @add_to_class(A)
    def do(self):
        current_b = self.b
        print('Class attribute "b" is', current_b)


    a.do()

  • class HyperParameters

    class HyperParameters:  #@save
        """The base class of hyperparameters.

        Subclasses call ``save_hyperparameters()`` from their ``__init__`` so
        that its arguments become instance attributes. This class only defines
        the interface; the method below is a placeholder.
        """
        def save_hyperparameters(self, ignore=None):
            """Save the calling ``__init__``'s arguments as attributes.

            Args:
                ignore: names of arguments that should NOT be saved;
                    ``None`` (the default) means ignore nothing. A fresh
                    default is used instead of a shared mutable list.

            Raises:
                NotImplementedError: always, in this placeholder.
            """
            # `raise NotImplemented` would actually raise TypeError, because
            # NotImplemented is a constant, not an exception class.
            raise NotImplementedError

    用于将类的 `__init__` 方法的参数保存为实例属性(例如 `self.a`)

    • save_hyperparameters()

      例子:

      # Call the fully implemented HyperParameters class saved in d2l
      class B(d2l.HyperParameters):
          """Demo: keep ``a`` and ``b`` as attributes but drop ``c``."""

          def __init__(self, a, b, c):
              # 'c' is listed in ignore, so it must not become an attribute.
              self.save_hyperparameters(ignore=['c'])
              print('self.a =', self.a, 'self.b =', self.b)
              c_is_absent = not hasattr(self, 'c')
              print('There is no self.c =', c_is_absent)


      b = B(a=1, b=2, c=3)

      输出:

      self.a = 1 self.b = 2
      There is no self.c = True

  • class ProgressBoard

    class ProgressBoard(d2l.HyperParameters):  #@save
        """The board that plots data points in animation.

        The constructor only records its arguments via
        ``save_hyperparameters``; ``draw`` is a placeholder here.
        """
        def __init__(self, xlabel=None, ylabel=None, xlim=None,
                     ylim=None, xscale='linear', yscale='linear',
                     ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                     fig=None, axes=None, figsize=(3.5, 2.5), display=True):
            # NOTE(review): the `ls` / `colors` list defaults are single objects
            # shared by every instance that relies on them; treat as read-only.
            self.save_hyperparameters()

        def draw(self, x, y, label, every_n=1):
            """Add the point (x, y) to the curve named ``label``.

            Args:
                x, y: coordinates of the new data point.
                label: name of the curve the point belongs to.
                every_n: presumably only every ``every_n``-th point is
                    rendered -- TODO confirm against the full d2l version.

            Raises:
                NotImplementedError: always, in this placeholder.
            """
            # `raise NotImplemented` would raise TypeError instead.
            raise NotImplementedError

    用于在实验进行过程中以动画形式动态绘制数据点(实验进度)

    • draw()

      例子:

      board = d2l.ProgressBoard('x')
      # (label, function, every_n) per curve; for each x they are drawn in
      # this order: sin first, then cos.
      curves = (('sin', np.sin, 2), ('cos', np.cos, 10))
      for x in np.arange(0, 10, 0.1):
          for label, fn, every in curves:
              board.draw(x, fn(x), label, every_n=every)

# 模块

class Module

class Module(nn.Module, d2l.HyperParameters):  #@save
    """The base class of models.

    Subclasses provide ``self.net`` (or override ``forward``), ``loss`` and
    ``configure_optimizers``. ``plot`` reads training progress from
    ``self.trainer``, which must be attached before training.
    """
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        """
        Args:
            plot_train_per_epoch: number of training-loss points plotted per epoch.
            plot_valid_per_epoch: number of validation-loss points plotted per epoch.
        """
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        """Return the loss of predictions ``y_hat`` against targets ``y``."""
        raise NotImplementedError

    def forward(self, X):
        """Run the wrapped network ``self.net`` on ``X``."""
        # The assertion fires when the network is missing, so say so.
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation.

        Args:
            key: metric name; prefixed with ``train_`` or ``val_`` as the label.
            value: the metric value (the loss from ``training_step`` /
                ``validation_step``); moved to CPU and detached before drawing.
            train: True for a training point, False for a validation point.
        """
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            # Fractional epoch position of the current training batch.
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            # Validation points sit at x = epoch + 1 on the epoch axis.
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        # With fewer batches than requested plot points int(n) would be 0,
        # i.e. "draw every 0 points"; clamp to at least 1.
        self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
                        ('train_' if train else 'val_') + key,
                        every_n=max(1, int(n)))

    def training_step(self, batch):
        """Compute, plot and return the loss on one training batch.

        ``batch[:-1]`` are passed to the model as inputs and ``batch[-1]`` is
        used as the target.
        """
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        """Compute and plot the loss on one validation batch (returns None)."""
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        """Return an optimizer, or a list of optimizers, for this model."""
        raise NotImplementedError

所有将要实现的模型的基类

__init__ 储存将要学习的参数 (learnable parameters)

training_step 接受一批数据,返回 loss

configure_optimizers 返回一个优化器,或者由多个优化器组成的列表

# 数据

class DataModule

class DataModule(d2l.HyperParameters):  #@save
    """Base class for datasets.

    ``root`` (presumably the data directory) and ``num_workers`` (presumably
    the loader worker count) are recorded as attributes through
    ``save_hyperparameters``. Subclasses implement ``get_dataloader``.
    """

    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        """Return the loader for the training split (``train=True``) or the
        validation split (``train=False``); must be overridden."""
        raise NotImplementedError

    def train_dataloader(self):
        """Shortcut for ``get_dataloader(train=True)``."""
        loader = self.get_dataloader(train=True)
        return loader

    def val_dataloader(self):
        """Shortcut for ``get_dataloader(train=False)``."""
        loader = self.get_dataloader(train=False)
        return loader

get_dataloader 返回一个逐批产出数据的生成器(generator);参数 train 决定返回训练集还是验证集的数据,train_dataloader 和 val_dataloader 分别以 train=True 和 train=False 调用它

# 训练

class Trainer

class Trainer(d2l.HyperParameters):  #@save
    """Base class that trains a model on a data module.

    ``fit`` drives the outer epoch loop; subclasses implement ``fit_epoch``.
    """

    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        """Fetch both loaders from ``data`` and record their batch counts."""
        train_loader = data.train_dataloader()
        self.train_dataloader = train_loader
        val_loader = data.val_dataloader()
        self.val_dataloader = val_loader
        self.num_train_batches = len(train_loader)
        # The validation loader may be absent (None): count zero batches.
        self.num_val_batches = 0 if val_loader is None else len(val_loader)

    def prepare_model(self, model):
        """Link ``model`` with this trainer and size the plot's x-axis."""
        model.trainer = self  # Module.plot reads progress through this link
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        """Train ``model`` on ``data`` for ``max_epochs`` epochs."""
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        # Reset counters; the batch indices are presumably advanced by fit_epoch.
        self.epoch = self.train_batch_idx = self.val_batch_idx = 0
        for epoch in range(self.max_epochs):
            self.epoch = epoch
            self.fit_epoch()

    def fit_epoch(self):
        """Run a single epoch; subclasses must override."""
        raise NotImplementedError

fit 接受 model(Module 的实例)和 data(DataModule 的实例):先准备数据与模型、获取优化器,然后循环 max_epochs 次调用 fit_epoch 进行训练(fit_epoch 需由子类实现)

[end]

mofianger

2024/2/1

图片和代码引自 en.d2l.ai

参考:3.2. Object-Oriented Design for Implementation — Dive into Deep Learning 1.0.3 documentation (d2l.ai)