pytorch_lightning使用体验
如果是一些小模型想要快速实验,不想怎么写代码的,可以通过pytorch_lightning快速搭建模型,但是如果涉及到大模型,以及分布式训练预测,咱还是老老实实用pytorch吧。
一. 使用体验
就像很多年前写过tensorflow之后看到keras后的欣喜,当我看到pytorch_lightning后瞬间就喜欢上了它!对于pytorch的重度使用者来说,每次都要写很多重复的训练预测代码,总感觉代码复用起来很麻烦,于是pytorch_lightning它来啦!
pytorch_lightning的优势:
-
代码可读性、复用性高
-
自由度和pytorch一样高,并没有像使用keras一样感觉封装过死的感觉。
-
能像keras一样快速搭建模型,简化模型训练和预测的过程
-
支持分布式训练
二. 安装和使用
官网地址是:https://lightning.ai/
pip进行安装:pip show pytorch_lightning
下面使用MNIST来展示如何使用pytorch_lightning来简化自己的代码
2.1 数据模块LightningDataModule
通常情况下,我们需要做一些预处理,以及在定义完自己的dataset后,需要定义dataloader,这里可以直接继承LightningDataModule模块,直接重写其中的方法即可。
class MNISTDataModule(LightningDataModule):
def __init__(self,root_dir,val_size,num_workers,batch_size):
super(MNISTDataModule, self).__init__()
self.save_hyperparameters()
def prepare_data(self):
"""
download data once
"""
MNIST(self.hparams.root_dir, train=True, download=True)
MNIST(self.hparams.root_dir, train=False, download=True)
def setup(self, stage=None):
"""
setup dataset for each machine
"""
dataset = MNIST(self.hparams.root_dir,
train=True,
download=False,
transform=T.ToTensor())
train_length = len(dataset)
self.train_dataset, self.val_dataset = \
random_split(dataset,
[train_length - self.hparams.val_size, self.hparams.val_size])
def train_dataloader(self):
return DataLoader(self.train_dataset,
shuffle=True,
num_workers=self.hparams.num_workers,
batch_size=self.hparams.batch_size,
pin_memory=True)
def val_dataloader(self):
return DataLoader(self.val_dataset,
shuffle=False,
num_workers=self.hparams.num_workers,
batch_size=self.hparams.batch_size,
pin_memory=True)
2.2 训练和预测模块LightningModule
之前每次训练和预测模型的时候,我都会写一个该过程的一个基类,来封装每个epoch模型训练、验证的过程,其实每次不同的项目、不同的模型继承了上述的基类,但是基本上也就是改变其中的每个batch训练、验证的方法,然后看到了一个别人封装的这么完美的训练预测基类,简直开心的不要不要的。。
class MNISTModel(LightningModule):
def __init__(self, hidden_dim, num_classes, lr, num_epochs):
super().__init__()
self.save_hyperparameters()
self.net = LinearModel(self.hparams.hidden_dim)
self.training_step_outputs = []
self.validation_step_outputs = []
self.loss_fn = nn.CrossEntropyLoss()
self.accuracy = torchmetrics.Accuracy(
task="multiclass", num_classes=self.hparams.num_classes
)
self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=self.hparams.num_classes)
def forward(self, x):
return self.net(x)
def configure_optimizers(self):
self.optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)
scheduler = CosineAnnealingLR(self.optimizer,
T_max=self.hparams.num_epochs,
eta_min=self.hparams.lr / 1e2)
return [self.optimizer], [scheduler]
def lr_scheduler_step(self, scheduler, *args, **kwargs):
scheduler.step()
def _common_step(self, batch, batch_idx):
images, labels = batch
logits_predicted = self(images)
loss = self.loss_fn(logits_predicted, labels)
acc = self.accuracy(logits_predicted, labels)
# acc = torch.sum(torch.eq(torch.argmax(logits_predicted, -1), labels).to(torch.float32)) / len(labels)
return loss, acc
def training_step(self, batch, batch_idx):
loss, acc = self._common_step(batch,batch_idx)
self.log('lr', get_learning_rate(self.optimizer))
self.log('train_step_loss', loss)
train_rs = {'train_loss': loss,
'train_acc': acc}
self.training_step_outputs.append(train_rs)
return loss
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
rs = self(batch[0])
rs = torch.argmax(rs, -1).numpy().tolist()
return rs
def validation_step(self, batch, batch_idx):
loss, acc = self._common_step(batch, batch_idx)
log = {'val_loss': loss,
'val_acc': acc}
self.validation_step_outputs.append(log)
return log
2.3 callbacks
这里如果我们觉得上面这些无法满足我们的日常训练、预测的需求,那么完全可以再增加一些其他需要的第三方和自己定义的callbacks,当然pytorch_lightning其实已经封装了很多常用的callbacks了,比如下面的几个常用的:
-
模型定义怎么保存ckpt:
ModelCheckpoint
-
如何定义训练及早停止:
MINISTCallBack
-
定义进度条:
TQDMProgressBar
当然了,我们想定义属于自己的callback怎么弄呢:
class MINISTCallBack(Callback):
def __init__(self):
super(MINISTCallBack, self).__init__()
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
print("Predict is ending")
def on_train_epoch_end(self, trainer : "pl.Trainer", pl_module: "pl.LightningModule"):
epoch_mean_loss = torch.stack([x['train_loss'] for x in pl_module.training_step_outputs]).mean()
epoch_mean_acc = torch.stack([x['train_acc'] for x in pl_module.training_step_outputs]).mean()
pl_module.log("train/loss", epoch_mean_loss, prog_bar=True)
pl_module.log("train/acc", epoch_mean_acc, prog_bar=True)
pl_module.training_step_outputs.clear()
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
epoch_mean_loss = torch.stack([x['val_loss'] for x in pl_module.validation_step_outputs]).mean()
epoch_mean_acc = torch.stack([x['val_acc'] for x in pl_module.validation_step_outputs]).mean()
pl_module.log('val/loss', epoch_mean_loss, prog_bar=True)
pl_module.log('val/acc', epoch_mean_acc, prog_bar=True)
pl_module.validation_step_outputs.clear()
2.4 调用
当我们都写完了上述我们定义好的数据模块,训练预测模块,那么如何使用呢?pytorch_lightning这里用了一个专门的类Trainer来调用。
训练调用:
trainer = Trainer(max_epochs=config.num_epochs,
# resume_from_checkpoint = 'ckpts/exp3/epoch=7.ckpt', # 断点续训
callbacks=callbacks,
logger=logger,
enable_model_summary=True, # 显示模型构造
accelerator='auto',
devices=1, # 多少个设备
deterministic=True,
num_sanity_val_steps=1, # 正式训练之前跑一次validation 测试程序是否出错
benchmark=True, # cudnn加速训练(要确保每个batch同一个大小)
)
# mnist_model.load_from_checkpoint('ckpts/exp3/epoch=7.ckpt')
trainer.fit(mnist_model,mnist_data)
预测调用,可以定义一个dataloader,也可以定义测试的数据模块,同时也能直接对单一一个tensor作为输入,进行预测:
rs = trainer.predict(mnist_model, dataloaders=test_loader)
rs = trainer.predict(mnist_model, datamodule=test_datamodule)
三. 分布式训练
pytorch_lightning也支持分布式,但是它只支持pytorch原生的DDP,作为被HuggingFace的accelerate圈粉的我。。。只能退坑了,拜拜👋🏻