# 人造数据** 主要内容:** 介绍了如何在面向对象的设计下,实现 d2l 框架中的 DataModule 的方法
** 人造数据的意义:** 用于评估模型的能力
# 如何生成数据集?自定义一个类,并继承自 DataModule,在 init 中完成数据的生成
例子:
1 2 3 4 5 6 7 8 9 10 class SyntheticRegressionData (d2l.DataModule): """Synthetic data for linear regression.""" def __init__ (self, w, b, noise=0.01 , num_train=1000 , num_val=1000 , batch_size=32 ): super ().__init__() self.save_hyperparameters() n = num_train + num_val self.X = torch.randn(n, len (w)) noise = torch.randn(n, 1 ) * noise self.y = torch.matmul(self.X, w.reshape((-1 , 1 ))) + b + noise
调用 init 和 save_hyperparameters 将参数变为类的属性
# 如何加载数据集?使用 add_to_class 修饰器重载 get_dataloader 方法
例子:
1 2 3 4 5 6 7 8 9 10 11 @d2l.add_to_class(SyntheticRegressionData ) def get_dataloader (self, train ): if train: indices = list (range (0 , self.num_train)) random.shuffle(indices) else : indices = list (range (self.num_train, self.num_train+self.num_val)) for i in range (0 , len (indices), self.batch_size): batch_indices = torch.tensor(indices[i: i+self.batch_size]) yield self.X[batch_indices], self.y[batch_indices]
train 用来判断数据是训练集(随机)还是验证集(按顺序)
# 更加简洁且高效的实现使用 Pytorch 等框架中内置的 API 加载数据
更高效,功能更多
实现:
1 2 3 4 5 6 @d2l.add_to_class(d2l.DataModule ) def get_tensorloader (self, tensors, train, indices=slice (0 , None ) ): tensors = tuple (a[indices] for a in tensors) dataset = torch.utils.data.TensorDataset(*tensors) return torch.utils.data.DataLoader(dataset, self.batch_size, shuffle=train)
tensors 是含有若干张量的元组,将 tensors 中每一个张量进行切片,再整合为新的 tensors
dataset 以新的 tensors 作为数据
最终返回一个内置的 DataLoader
重载 get_dataloader:
1 2 3 4 @d2l.add_to_class(SyntheticRegressionData ) def get_dataloader (self, train ): i = slice (0 , self.num_train) if train else slice (self.num_train, None ) return self.get_tensorloader((self.X, self.y), train, i)
[end]
mofianger
2024/2/1
代码引自 en.d2l.ai
参考:3.3. Synthetic Regression Data — Dive into Deep Learning 1.0.3 documentation (d2l.ai)