Getting started
For a quick hands-on we are going to go through creating a MNIST classifier step by step (examples/mnist.py
).
First we can see the model definition of a simple CNN classifier:
class Net(nn.Module):
def __init__(self, intermediate_hidden=50):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, intermediate_hidden)
self.fc2 = nn.Linear(intermediate_hidden, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return x
The get_data
function downloads MNIST and performs the preprocessing:
def get_data():
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train = MNIST(download=True, root=".", transform=data_transform, train=True)
val = MNIST(download=False, root=".", transform=data_transform, train=False)
return train, val
Finally the get_parser
function creates a CLI parser for the model arguments. Note the destination variable model.intermediate_hidden
(important for later):
def get_parser():
parser = ArgumentParser("MNIST classification example")
parser.add_argument(
"--hidden",
dest="model.intermediate_hidden",
type=int,
default=12,
help="Intermediate hidden layers for linear module",
)
return parser
Setup
You would create these functions whether you use slp or not. The interesting part is in the main function.
Here we first need to perform some setup for logging, configuration parsing and seeding. Let's go through this step-by-step:
First we create the CLI args parser. make_cli_parser
takes the parser we defined in get_parser() and extends it with generic arguments for the data, optimizers, learning rate scheduling, training and experiment tracking.
Run python mnist.py --help
for a full list of arguments, or go down in the Appendix.
The arguments have a detailed description and most of them default to None
so they will not be used.
parser = get_parser()
parser = make_cli_parser(parser, PLDataModuleFromDatasets)
Next we parse the configuration. Here we need to provide the parser and (optionally) a YAML configuration file python mnist.py --config my-config.yaml
.
The configuration file should have the following format:
model:
intermediate_hidden: 100
optimizer: Adam
optim:
lr: 1e-3
lr_scheduler: true # ReduceLROnPlateau
lr_schedule:
factor: 2
data:
batch_size: 128
batch_size_eval: 256
Note that this format, closely follows the dest
values we configure in the command line args (e.g. OPTIM.LR
, LR_SCHEDULE.FACTOR
), namely the dots. form a hierarchy.
This way we can use parse_config
to merge the values in the configuration file and the CLI args.
The precedence is as follows:
default CLI args < config file values < user provided CLI args
So if we call the script with --lr 1e-4
this value will overwrite the value in the configuration file.
If a value is not specified in the configuration file, the default value we specified in argparse will be set.
If a value is not specified in any of these places, sane defaults will be used.
config = parse_config(parser, parser.parse_args().config)
if config.trainer.experiment_name == "experiment":
config.trainer.experiment_name = "mnist-classification"
Next, we configure logging. This call configures loguru
to intercept all logs and print the both to stdout and a log file.
The log file name will depend on the experiment name we provided and datetime.now()
, to avoid overwriting previous runs (e.g. mnist-classification.20210302-134714.log
).
configure_logging(f"logs/{config.trainer.experiment_name}")
Finally, we make the run deterministic (--seed
)
if config.seed is not None:
logger.info("Seeding everything with seed={seed}")
pl.utilities.seed.seed_everything(seed=config.seed)
Data Module
Here we download the train and test datasets and define the LightningDataModule
that will be used in this experiment.
The LightningDataModule
is the preferred way to consume datasets in pytorch lightning, and PLDataModuleFromDatasets
abstracts the boilerplate of constructing and configuring the DataLoaders
, splitting data etc.
Note: PLDataModuleFromDatasets
expects three torch.utils.data.Datasets
as input, train, val and test. Val and test are optional. If any of the validation or tests sets are not provided, PLDataModuleFromDatasets
will create a split using 20% of the train set by default (see --val-percent
, --test-percent
).
train, test = get_data()
ldm = PLDataModuleFromDatasets(train, test=test, seed=config.seed, **config.data) # Note we pass **config.data, because config is in a hierarchy.
Defining the model
Next we define the model, optimizer, criterion and learning rate scheduler (pretty standard).
model = Net(**config.model)
optimizer = getattr(optim, config.optimizer)(model.parameters(), **config.optim)
criterion = nn.CrossEntropyLoss()
lr_scheduler = None
if config.lr_scheduler:
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, **config.lr_schedule
)
And the LightningModule
that will be used for training. This module takes care of defining the train and validation steps, computing and logging metrics etc.
Note that pytorch lightning by default expects softmaxed outputs in the predefined metrics, so we wrap the metric with FromLogits
to use crossentropy loss.
lm = PLModule(
model,
optimizer,
criterion,
lr_scheduler=lr_scheduler,
metrics={"acc": FromLogits(pl.metrics.classification.Accuracy())},
hparams=config, # We pass this so that configuration will be logged in wandb
)
Training and Debugging
Finally, we have the option to run a full training and testing phase of the model, or run a quick debug execution.
If we need to debug, we can pass --debug
flag, and the trainer will run a full training, validation run on 5 batches.
It will also try to overfit the model on 5 batches to verify that gradients flow.
# Run debugging session or fit & test the model ############
if config.debug:
logger.info("Running in debug mode: Fast run on 5 batches")
trainer = make_trainer(fast_dev_run=5)
trainer.fit(lm, datamodule=ldm)
logger.info("Running in debug mode: Overfitting 5 batches")
trainer = make_trainer(overfit_batches=5)
trainer.fit(lm, datamodule=ldm)
If we run in normal mode, we fit on train / val sets and evaluate the best model on the test set. The best model is selected as the model with the smallest validation loss.
Training will run with early stopping, the best 3 checkpoints will be saved, all the jazz.
Note watch_model
tells wandb to track weight norms and gradients for further inspection.
else:
trainer = make_trainer(**config.trainer)
watch_model(trainer, model)
trainer.fit(lm, datamodule=ldm)
trainer.test(ckpt_path="best", test_dataloaders=ldm.test_dataloader())
logger.info("Run finished. Uploading files to wandb...")
Sure, most of this goodness comes from the awesome team in pytorch lightning. But what we do here is we abstract the boilerplate and the large learning curve, without making sacrifices in the features.
For example, you can play spot the differences between examples/mnist.py
, which performs digit classification and examples/smt_bert.py
which finetunes BERT for sentiment classification on SST-2.
The classes we use change, but the way they are called, the structure and features remains the same.
# smt_bert.py
...
if __name__ == "__main__":
parser = get_parser()
parser = make_cli_parser(parser, PLDataModuleFromCorpus)
args = parser.parse_args()
config_file = args.config
config = parse_config(parser, config_file)
# Set these by default.
config.hugging_face_model = config.data.tokenizer
config.data.add_special_tokens = True
config.data.lower = "uncased" in config.hugging_face_model
if config.trainer.experiment_name == "experiment":
config.trainer.experiment_name = "finetune-bert-smt"
configure_logging(f"logs/{config.trainer.experiment_name}")
if config.seed is not None:
logger.info("Seeding everything with seed={seed}")
pl.utilities.seed.seed_everything(seed=config.seed)
(
raw_train,
labels_train,
raw_dev,
labels_dev,
raw_test,
labels_test,
num_labels,
) = get_data(config)
ldm = PLDataModuleFromCorpus(
raw_train,
labels_train,
val=raw_dev,
val_labels=labels_dev,
test=raw_test,
test_labels=labels_test,
collate_fn=collate_fn,
**config.data,
)
model = BertForSequenceClassification.from_pretrained(
config.hugging_face_model, num_labels=num_labels
)
logger.info(model)
# Leave this hardcoded for now.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-5)
criterion = nn.CrossEntropyLoss()
lm = BertPLModule(
model,
optimizer,
criterion,
metrics={"acc": FromLogits(pl.metrics.classification.Accuracy())},
)
trainer = make_trainer(**config.trainer)
watch_model(trainer, model)
trainer.fit(lm, datamodule=ldm)
trainer.test(ckpt_path="best", test_dataloaders=ldm.test_dataloader())
Appendix. Command Line arguments
usage: mnist.py [-h] [--hidden MODEL.INTERMEDIATE_HIDDEN]
[--optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}] [--lr OPTIM.LR]
[--weight-decay OPTIM.WEIGHT_DECAY] [--lr-scheduler]
[--lr-factor LR_SCHEDULE.FACTOR] [--lr-patience LR_SCHEDULE.PATIENCE]
[--lr-cooldown LR_SCHEDULE.COOLDOWN] [--min-lr LR_SCHEDULE.MIN_LR] [--seed SEED]
[--config CONFIG] [--experiment-name TRAINER.EXPERIMENT_NAME]
[--run-id TRAINER.RUN_ID] [--experiment-group TRAINER.EXPERIMENT_GROUP]
[--experiments-folder TRAINER.EXPERIMENTS_FOLDER] [--save-top-k TRAINER.SAVE_TOP_K]
[--patience TRAINER.PATIENCE] [--wandb-project TRAINER.WANDB_PROJECT]
[--tags [TRAINER.TAGS [TRAINER.TAGS ...]]] [--stochastic_weight_avg]
[--gpus TRAINER.GPUS] [--val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH]
[--clip-grad-norm TRAINER.GRADIENT_CLIP_VAL] [--epochs TRAINER.MAX_EPOCHS]
[--steps TRAINER.MAX_STEPS] [--tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS] [--debug]
[--val-percent DATA.VAL_PERCENT] [--test-percent DATA.TEST_PERCENT]
[--bsz DATA.BATCH_SIZE] [--bsz-eval DATA.BATCH_SIZE_EVAL]
[--num-workers DATA.NUM_WORKERS] [--pin-memory] [--drop-last] [--shuffle-eval]
optional arguments:
-h, --help show this help message and exit
--hidden MODEL.INTERMEDIATE_HIDDEN
Intermediate hidden layers for linear module
--optimizer {Adam,AdamW,SGD,Adadelta,Adagrad,Adamax,ASGD,RMSprop}
Which optimizer to use
--lr OPTIM.LR Learning rate
--weight-decay OPTIM.WEIGHT_DECAY
Learning rate
--lr-scheduler Use learning rate scheduling. Currently only ReduceLROnPlateau is supported
out of the box
--lr-factor LR_SCHEDULE.FACTOR
Multiplicative factor by which LR is reduced. Used if --lr-scheduler is
provided.
--lr-patience LR_SCHEDULE.PATIENCE
Number of epochs with no improvement after which learning rate will be
reduced. Used if --lr-scheduler is provided.
--lr-cooldown LR_SCHEDULE.COOLDOWN
Number of epochs to wait before resuming normal operation after lr has been
reduced. Used if --lr-scheduler is provided.
--min-lr LR_SCHEDULE.MIN_LR
Minimum lr for LR scheduling. Used if --lr-scheduler is provided.
--seed SEED Seed for reproducibility
--config CONFIG Path to YAML configuration file
--experiment-name TRAINER.EXPERIMENT_NAME
Name of the running experiment
--run-id TRAINER.RUN_ID
Unique identifier for the current run. If not provided it is inferred from
datetime.now()
--experiment-group TRAINER.EXPERIMENT_GROUP
Group of current experiment. Useful when evaluating for different seeds /
cross-validation etc.
--experiments-folder TRAINER.EXPERIMENTS_FOLDER
Top-level folder where experiment results & checkpoints are saved
--save-top-k TRAINER.SAVE_TOP_K
Save checkpoints for top k models
--patience TRAINER.PATIENCE
Number of epochs to wait before early stopping
--wandb-project TRAINER.WANDB_PROJECT
Wandb project under which results are saved
--tags [TRAINER.TAGS [TRAINER.TAGS ...]]
Tags for current run to make results searchable.
--stochastic_weight_avg
Use Stochastic weight averaging.
--gpus TRAINER.GPUS Number of GPUs to use
--val-interval TRAINER.CHECK_VAL_EVERY_N_EPOCH
Run validation every n epochs
--clip-grad-norm TRAINER.GRADIENT_CLIP_VAL
Clip gradients with ||grad(w)|| >= args.clip_grad_norm
--epochs TRAINER.MAX_EPOCHS
Maximum number of training epochs
--steps TRAINER.MAX_STEPS
Maximum number of training steps
--tbtt_steps TRAINER.TRUNCATED_BPTT_STEPS
Truncated Back-propagation-through-time steps.
--debug If true, we run a full run on a small subset of the input data and overfit
10 training batches
--val-percent DATA.VAL_PERCENT
Percent of validation data to be randomly split from the training set, if
no validation set is provided
--test-percent DATA.TEST_PERCENT
Percent of test data to be randomly split from the training set, if no test
set is provided
--bsz DATA.BATCH_SIZE
Training batch size
--bsz-eval DATA.BATCH_SIZE_EVAL
Evaluation batch size
--num-workers DATA.NUM_WORKERS
Number of workers to be used in the DataLoader
--pin-memory Pin data to GPU memory for faster data loading
--drop-last Drop last incomplete batch
--shuffle-eval Shuffle val & test sets