
Pytorch Lightning Bindings

These bindings help build multi-purpose LightningModules and LightningDataModules that can be reused across many tasks and datasets.

Note that this goes against the pytorch-lightning mantra that everything about an experiment should be contained in a single module.

I agree that this helps reproducibility, but I find it tedious to always copy and paste, or even worse rewrite, boilerplate code for metric calculation and hook implementations.

With the extensive logging and sane configuration management in slp, reproducibility while developing new models is less of an issue.

My current workflow is to use these modules for fast development cycles. When I need to publish a specific model, I copy and paste it into an isolated LightningModule, which makes it easier for future readers. This way we get the best of both worlds.

PLDataModuleFromCorpus

embeddings: Optional[numpy.ndarray] property readonly

Embeddings matrix

Returns:

Type | Description
Optional[numpy.ndarray] | Embeddings matrix

vocab_size: int property readonly

Number of tokens in the vocabulary

Returns:

Type | Description
int | Number of tokens in the vocabulary

__init__(self, train, train_labels=None, val=None, val_labels=None, test=None, test_labels=None, val_percent=0.2, test_percent=0.2, batch_size=64, batch_size_eval=None, seed=None, num_workers=1, pin_memory=True, drop_last=False, shuffle_eval=False, sampler_train=None, sampler_val=None, sampler_test=None, batch_sampler_train=None, batch_sampler_val=None, batch_sampler_test=None, collate_fn=None, language_model=False, tokenizer='spacy', no_test_set=False, **corpus_args) special

Wrap raw corpus in a LightningDataModule

  • Selects the appropriate corpus class based on the tokenizer argument.
  • If language_model=True, uses the appropriate dataset from slp.data.datasets.
  • Uses PLDataModuleFromDatasets to split off the val and test sets if they are not provided.

Parameters:

Name | Type | Description | Default
train | List | Raw train corpus | required
train_labels | Optional[List] | Train labels. Defaults to None. | None
val | Optional[List] | Raw validation corpus. Defaults to None. | None
val_labels | Optional[List] | Validation labels. Defaults to None. | None
test | Optional[List] | Raw test corpus. Defaults to None. | None
test_labels | Optional[List] | Test labels. Defaults to None. | None
val_percent | float | Fraction of the train set used for validation if no validation set is given. Defaults to 0.2. | 0.2
test_percent | float | Fraction of the train set used for the test set if no test set is given. Defaults to 0.2. | 0.2
batch_size | int | Training batch size. Defaults to 64. | 64
batch_size_eval | Optional[int] | Validation and test batch size. If None, batch_size is used. Defaults to None. | None
seed | Optional[int] | Seed for deterministic run. Defaults to None. | None
num_workers | int | Number of workers in the DataLoader. Defaults to 1. | 1
pin_memory | bool | Copy batches into pinned (page-locked) host memory for faster transfer to the GPU. Defaults to True. | True
drop_last | bool | Drop last incomplete batch. Defaults to False. | False
sampler_train | Sampler | Sampler for train loader. Defaults to None. | None
sampler_val | Sampler | Sampler for validation loader. Defaults to None. | None
sampler_test | Sampler | Sampler for test loader. Defaults to None. | None
batch_sampler_train | BatchSampler | Batch sampler for train loader. Defaults to None. | None
batch_sampler_val | BatchSampler | Batch sampler for validation loader. Defaults to None. | None
batch_sampler_test | BatchSampler | Batch sampler for test loader. Defaults to None. | None
shuffle_eval | bool | Shuffle validation and test dataloaders. Defaults to False. | False
collate_fn | Optional[Callable[..., Any]] | Collator function. Defaults to None. | None
language_model | bool | Use the corpus for language modeling. Defaults to False. | False
tokenizer | str | Select one of cls.accepted_tokenizers. Defaults to "spacy". | 'spacy'
no_test_set | bool | Do not create a test set. Useful for tuning. Defaults to False. | False
**corpus_args | kwargs | Extra arguments to be passed to the corpus. See slp/data/corpus.py. | {}

Exceptions:

Type | Description
ValueError | [description]
ValueError | [description]

Source code in slp/plbind/dm.py
def __init__(
    self,
    train: List,
    train_labels: Optional[List] = None,
    val: Optional[List] = None,
    val_labels: Optional[List] = None,
    test: Optional[List] = None,
    test_labels: Optional[List] = None,
    val_percent: float = 0.2,
    test_percent: float = 0.2,
    batch_size: int = 64,
    batch_size_eval: int = None,
    seed: int = None,
    num_workers: int = 1,
    pin_memory: bool = True,
    drop_last: bool = False,
    shuffle_eval: bool = False,
    sampler_train: Sampler = None,
    sampler_val: Sampler = None,
    sampler_test: Sampler = None,
    batch_sampler_train: BatchSampler = None,
    batch_sampler_val: BatchSampler = None,
    batch_sampler_test: BatchSampler = None,
    collate_fn: Optional[Callable[..., Any]] = None,
    language_model: bool = False,
    tokenizer: str = "spacy",
    no_test_set: bool = False,
    **corpus_args,
):
    """Wrap raw corpus in a LightningDataModule

    * This handles the selection of the appropriate corpus class based on the tokenizer argument.
    * If language_model=True it uses the appropriate dataset from slp.data.datasets.
    * Uses the PLDataModuleFromDatasets to split the val and test sets if not provided

    Args:
        train (List): Raw train corpus
        train_labels (Optional[List]): Train labels. Defaults to None.
        val (Optional[List]): Raw validation corpus. Defaults to None.
        val_labels (Optional[List]): Validation labels. Defaults to None.
        test (Optional[List]): Raw test corpus. Defaults to None.
        test_labels (Optional[List]): Test labels. Defaults to None.
        val_percent (float): Percent of train to be used for validation if no validation set is given. Defaults to 0.2.
        test_percent (float): Percent of train to be used for test set if no test set is given. Defaults to 0.2.
        batch_size (int): Training batch size. Defaults to 64.
        batch_size_eval (Optional[int]): Validation and test batch size. Defaults to None.
        seed (Optional[int]): Seed for deterministic run. Defaults to None.
        num_workers (int): Number of workers in the DataLoader. Defaults to 1.
        pin_memory (bool): Copy batches into pinned (page-locked) host memory for faster GPU transfer. Defaults to True.
        drop_last (bool): Drop last incomplete batch. Defaults to False.
        sampler_train (Sampler): Sampler for train loader. Defaults to None.
        sampler_val (Sampler): Sampler for validation loader. Defaults to None.
        sampler_test (Sampler): Sampler for test loader. Defaults to None.
        batch_sampler_train (BatchSampler): Batch sampler for train loader. Defaults to None.
        batch_sampler_val (BatchSampler): Batch sampler for validation loader. Defaults to None.
        batch_sampler_test (BatchSampler): Batch sampler for test loader. Defaults to None.
        shuffle_eval (bool): Shuffle validation and test dataloaders. Defaults to False.
        collate_fn (Callable[..., Any]): Collator function. Defaults to None.
        language_model (bool): Use corpus for Language Modeling. Defaults to False.
        tokenizer (str): Select one of the cls.accepted_tokenizers. Defaults to "spacy".
        no_test_set (bool): Do not create a test set. Useful for tuning.
        **corpus_args (kwargs): Extra arguments to be passed to the corpus. See
            slp/data/corpus.py
    Raises:
        ValueError: [description]
        ValueError: [description]
    """
    self.language_model = language_model
    self.tokenizer = tokenizer
    self.corpus_args = corpus_args

    train_data, val_data, test_data = self._zip_corpus_and_labels(
        train, val, test, train_labels, val_labels, test_labels
    )

    self.no_test_set = no_test_set
    super(PLDataModuleFromCorpus, self).__init__(
        train_data,  # type: ignore
        val=val_data,  # type: ignore
        test=test_data,  # type: ignore
        val_percent=val_percent,
        test_percent=test_percent,
        batch_size=batch_size,
        batch_size_eval=batch_size_eval,
        seed=seed,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        shuffle_eval=shuffle_eval,
        sampler_train=sampler_train,
        sampler_val=sampler_val,
        sampler_test=sampler_test,
        batch_sampler_train=batch_sampler_train,
        batch_sampler_val=batch_sampler_val,
        batch_sampler_test=batch_sampler_test,
        collate_fn=collate_fn,
        no_test_set=no_test_set,
    )

add_argparse_args(parent_parser) classmethod

Augment input parser with arguments for data loading and corpus processing

Parameters:

Name | Type | Description | Default
parent_parser | argparse.ArgumentParser | Parser created by the user | required

Returns:

Type | Description
argparse.ArgumentParser | Augmented parser

Source code in slp/plbind/dm.py
@classmethod
def add_argparse_args(cls, parent_parser):
    """Augment input parser with arguments for data loading and corpus processing

    Args:
        parent_parser (argparse.ArgumentParser): Parser created by the user

    Returns:
        argparse.ArgumentParser: Augmented parser
    """
    parser = super(PLDataModuleFromCorpus, cls).add_argparse_args(parent_parser)
    parser.add_argument(
        "--tokenizer",
        dest="data.tokenizer",
        type=str.lower,
        # Corpus can already be tokenized, you can use spacy for word tokenization or any tokenizer from hugging face
        choices=cls.accepted_tokenizers,
        default="spacy",
        help="Token type. The tokenization will happen at this level.",
    )

    # Only when tokenizer == spacy
    parser.add_argument(
        "--limit-vocab",
        dest="data.limit_vocab_size",
        type=int,
        default=-1,
        help="Limit vocab size. -1 means use the whole vocab. Applicable only when --tokenizer=spacy",
    )

    parser.add_argument(
        "--embeddings-file",
        dest="data.embeddings_file",
        type=dir_path,
        default=None,
        help="Path to file with pretrained embeddings. Applicable only when --tokenizer=spacy",
    )

    parser.add_argument(
        "--embeddings-dim",
        dest="data.embeddings_dim",
        type=int,
        default=50,
        help="Embedding dim of pretrained embeddings. Applicable only when --tokenizer=spacy",
    )

    parser.add_argument(
        "--lang",
        dest="data.lang",
        type=str,
        default="en_core_web_md",
        help="Language for spacy tokenizer, e.g. en_core_web_md. Applicable only when --tokenizer=spacy",
    )

    parser.add_argument(
        "--no-add-specials",
        dest="data.add_special_tokens",
        action="store_false",
        help="Do not add special tokens for hugging face tokenizers",
    )

    # Generic args
    parser.add_argument(
        "--lower",
        dest="data.lower",
        action="store_true",
        help="Convert to lowercase.",
    )

    parser.add_argument(
        "--prepend-bos",
        dest="data.prepend_bos",
        action="store_true",
        help="Prepend [BOS] token",
    )

    parser.add_argument(
        "--append-eos",
        dest="data.append_eos",
        action="store_true",
        help="Append [EOS] token",
    )

    parser.add_argument(
        "--max-sentence-length",
        dest="data.max_len",
        type=int,
        default=-1,
        help="Maximum allowed sentence length. -1 means use the whole sentence",
    )

    return parser
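Every flag above stores its value under a dotted dest such as data.tokenizer, so the parsed Namespace holds flat keys that contain dots. A minimal sketch of reading those keys back and grouping them into a nested dict; nest_dotted is a hypothetical helper (how slp itself merges these keys into its configuration is not shown here), and the choices list is a stand-in for cls.accepted_tokenizers.

```python
import argparse


def nest_dotted(namespace: argparse.Namespace) -> dict:
    """Hypothetical helper: {"data.tokenizer": "spacy"} -> {"data": {"tokenizer": "spacy"}}."""
    nested: dict = {}
    for key, value in vars(namespace).items():
        *parents, leaf = key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


parser = argparse.ArgumentParser()
# Same dest style as the corpus flags above; the choices are a hypothetical
# stand-in for cls.accepted_tokenizers.
parser.add_argument(
    "--tokenizer",
    dest="data.tokenizer",
    type=str.lower,  # applied before the choices check, so "SPACY" is accepted
    choices=["spacy", "whitespace"],
    default="spacy",
)
parser.add_argument("--lower", dest="data.lower", action="store_true")

args = parser.parse_args(["--tokenizer", "SPACY", "--lower"])
# "data.tokenizer" is not a valid identifier, so args.data.tokenizer fails; use getattr.
print(getattr(args, "data.tokenizer"))  # spacy
config = nest_dotted(args)
print(config)  # {'data': {'tokenizer': 'spacy', 'lower': True}}
```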

PLDataModuleFromDatasets

__init__(self, train, val=None, test=None, val_percent=0.2, test_percent=0.2, batch_size=1, batch_size_eval=None, seed=None, num_workers=1, pin_memory=True, drop_last=False, sampler_train=None, sampler_val=None, sampler_test=None, batch_sampler_train=None, batch_sampler_val=None, batch_sampler_test=None, shuffle_eval=False, collate_fn=None, no_test_set=False) special

LightningDataModule wrapper for generic torch.utils.data.Dataset

If val or test Datasets are not provided, this class splits off val_percent and test_percent of the train set, respectively, to create them.

Parameters:

Name | Type | Description | Default
train | Dataset | Train set | required
val | Dataset | Validation set. Defaults to None. | None
test | Dataset | Test set. Defaults to None. | None
val_percent | float | Fraction of the train set used for validation if no validation set is given. Defaults to 0.2. | 0.2
test_percent | float | Fraction of the train set used for the test set if no test set is given. Defaults to 0.2. | 0.2
batch_size | int | Training batch size. Defaults to 1. | 1
batch_size_eval | Optional[int] | Validation and test batch size. If None, batch_size is used. Defaults to None. | None
seed | Optional[int] | Seed for deterministic run. Defaults to None. | None
num_workers | int | Number of workers in the DataLoader. Defaults to 1. | 1
pin_memory | bool | Copy batches into pinned (page-locked) host memory for faster transfer to the GPU. Defaults to True. | True
drop_last | bool | Drop last incomplete batch. Defaults to False. | False
sampler_train | Sampler | Sampler for train loader. Defaults to None. | None
sampler_val | Sampler | Sampler for validation loader. Defaults to None. | None
sampler_test | Sampler | Sampler for test loader. Defaults to None. | None
batch_sampler_train | BatchSampler | Batch sampler for train loader. Defaults to None. | None
batch_sampler_val | BatchSampler | Batch sampler for validation loader. Defaults to None. | None
batch_sampler_test | BatchSampler | Batch sampler for test loader. Defaults to None. | None
shuffle_eval | bool | Shuffle validation and test dataloaders. Defaults to False. | False
collate_fn | Optional[Callable[..., Any]] | Collator function. Defaults to None. | None
no_test_set | bool | Do not create a test set. Useful for tuning. Defaults to False. | False

Exceptions:

Type | Description
ValueError | If both mutually exclusive sampler_train and batch_sampler_train are provided
ValueError | If both mutually exclusive sampler_val and batch_sampler_val are provided
ValueError | If both mutually exclusive sampler_test and batch_sampler_test are provided

Source code in slp/plbind/dm.py
def __init__(
    self,
    train: Dataset,
    val: Dataset = None,
    test: Dataset = None,
    val_percent: float = 0.2,
    test_percent: float = 0.2,
    batch_size: int = 1,
    batch_size_eval: Optional[int] = None,
    seed: Optional[int] = None,
    num_workers: int = 1,
    pin_memory: bool = True,
    drop_last: bool = False,
    sampler_train: Sampler = None,
    sampler_val: Sampler = None,
    sampler_test: Sampler = None,
    batch_sampler_train: BatchSampler = None,
    batch_sampler_val: BatchSampler = None,
    batch_sampler_test: BatchSampler = None,
    shuffle_eval: bool = False,
    collate_fn: Optional[Callable[..., Any]] = None,
    no_test_set: bool = False,
):
    """LightningDataModule wrapper for generic torch.utils.data.Dataset

    If val or test Datasets are not provided, this class will split
    val_percent and test_percent of the train set, respectively, to create them.

    Args:
        train (Dataset): Train set
        val (Dataset): Validation set. Defaults to None.
        test (Dataset): Test set. Defaults to None.
        val_percent (float): Percent of train to be used for validation if no validation set is given. Defaults to 0.2.
        test_percent (float): Percent of train to be used for test set if no test set is given. Defaults to 0.2.
        batch_size (int): Training batch size. Defaults to 1.
        batch_size_eval (Optional[int]): Validation and test batch size. Defaults to None.
        seed (Optional[int]): Seed for deterministic run. Defaults to None.
        num_workers (int): Number of workers in the DataLoader. Defaults to 1.
        pin_memory (bool): Copy batches into pinned (page-locked) host memory for faster GPU transfer. Defaults to True.
        drop_last (bool): Drop last incomplete batch. Defaults to False.
        sampler_train (Sampler): Sampler for train loader. Defaults to None.
        sampler_val (Sampler): Sampler for validation loader. Defaults to None.
        sampler_test (Sampler): Sampler for test loader. Defaults to None.
        batch_sampler_train (BatchSampler): Batch sampler for train loader. Defaults to None.
        batch_sampler_val (BatchSampler): Batch sampler for validation loader. Defaults to None.
        batch_sampler_test (BatchSampler): Batch sampler for test loader. Defaults to None.
        shuffle_eval (bool): Shuffle validation and test dataloaders. Defaults to False.
        collate_fn (Callable[..., Any]): Collator function. Defaults to None.
        no_test_set (bool): Do not create a test set. Useful for tuning.

    Raises:
        ValueError: If both mutually exclusive sampler_train and batch_sampler_train are provided
        ValueError: If both mutually exclusive sampler_val and batch_sampler_val are provided
        ValueError: If both mutually exclusive sampler_test and batch_sampler_test are provided
    """
    super(PLDataModuleFromDatasets, self).__init__()
    self.setup_has_run = False
    if batch_sampler_train is not None and sampler_train is not None:
        raise ValueError(
            "You provided both a sampler and a batch sampler for the train set. These are mutually exclusive"
        )

    if batch_sampler_val is not None and sampler_val is not None:
        raise ValueError(
            "You provided both a sampler and a batch sampler for the validation set. These are mutually exclusive"
        )

    if batch_sampler_test is not None and sampler_test is not None:
        raise ValueError(
            "You provided both a sampler and a batch sampler for the test set. These are mutually exclusive"
        )
    self.val_percent = val_percent
    self.test_percent = test_percent
    self.sampler_train = sampler_train
    self.sampler_val = sampler_val
    self.sampler_test = sampler_test
    self.batch_sampler_train = batch_sampler_train
    self.batch_sampler_val = batch_sampler_val
    self.batch_sampler_test = batch_sampler_test
    self.num_workers = num_workers
    self.pin_memory = pin_memory
    self.drop_last = drop_last

    self.shuffle_eval = shuffle_eval
    self.collate_fn = collate_fn

    self.batch_size = batch_size
    self.seed = seed

    if batch_size_eval is None:
        batch_size_eval = self.batch_size

    self.no_test_set = no_test_set
    self.batch_size_eval = batch_size_eval
    self.train = train
    self.val = val
    self.test = test
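The constructor validates the sampler arguments up front and falls back to batch_size for evaluation when batch_size_eval is None. A stdlib-only sketch of those two rules; check_samplers and eval_batch_size are hypothetical names, and plain object() instances stand in for real Sampler and BatchSampler objects. Checking early gives a clearer error at construction time, since torch.utils.data.DataLoader would otherwise only reject a batch_sampler combined with a sampler when the loader is built.

```python
from typing import Any, Dict, Optional

SPLIT_LABELS = {"train": "train", "val": "validation", "test": "test"}


def check_samplers(
    samplers: Dict[str, Optional[Any]], batch_samplers: Dict[str, Optional[Any]]
) -> None:
    """Raise ValueError if any split has both a sampler and a batch sampler."""
    for split, label in SPLIT_LABELS.items():
        if samplers.get(split) is not None and batch_samplers.get(split) is not None:
            raise ValueError(
                f"You provided both a sampler and a batch sampler for the {label} set. "
                "These are mutually exclusive"
            )


def eval_batch_size(batch_size: int, batch_size_eval: Optional[int]) -> int:
    """Fall back to the training batch size when no eval batch size is given."""
    return batch_size if batch_size_eval is None else batch_size_eval


# A sampler on one split and a batch sampler on another split is allowed.
check_samplers({"train": object()}, {"val": object()})

error_message = ""
try:
    check_samplers({"test": object()}, {"test": object()})
except ValueError as err:
    error_message = str(err)

print(error_message)
print(eval_batch_size(64, None), eval_batch_size(64, 128))  # 64 128
```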

add_argparse_args(parent_parser) classmethod

Augment input parser with arguments for data loading

Parameters:

Name | Type | Description | Default
parent_parser | ArgumentParser | Parser created by the user | required

Returns:

Type | Description
ArgumentParser | Augmented parser

Source code in slp/plbind/dm.py
@classmethod
def add_argparse_args(
    cls, parent_parser: argparse.ArgumentParser
) -> argparse.ArgumentParser:
    """Augment input parser with arguments for data loading

    Args:
        parent_parser (argparse.ArgumentParser): Parser created by the user

    Returns:
        argparse.ArgumentParser: Augmented parser
    """
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument(
        "--val-percent",
        dest="data.val_percent",
        type=float,
        default=0.2,
        help="Percent of validation data to be randomly split from the training set, if no validation set is provided",
    )

    parser.add_argument(
        "--test-percent",
        dest="data.test_percent",
        type=float,
        default=0.2,
        help="Percent of test data to be randomly split from the training set, if no test set is provided",
    )

    parser.add_argument(
        "--bsz",
        dest="data.batch_size",
        type=int,
        default=32,
        help="Training batch size",
    )

    parser.add_argument(
        "--bsz-eval",
        dest="data.batch_size_eval",
        type=int,
        default=32,
        help="Evaluation batch size",
    )

    parser.add_argument(
        "--num-workers",
        dest="data.num_workers",
        type=int,
        default=1,
        help="Number of workers to be used in the DataLoader",
    )

    parser.add_argument(
        "--no-pin-memory",
        dest="data.pin_memory",
        action="store_false",
        help="Don't copy batches into pinned host memory before transferring them to the GPU",
    )

    parser.add_argument(
        "--drop-last",
        dest="data.drop_last",
        action="store_true",
        help="Drop last incomplete batch",
    )

    parser.add_argument(
        "--no-shuffle-eval",
        dest="data.shuffle_eval",
        action="store_false",
        help="Don't shuffle val & test sets",
    )

    return parser
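The boolean flags above rely on argparse's store_true and store_false actions, whose defaults are the opposite of the value they store. A minimal sketch reproducing three of the flags (same flag, dest and action) to show what the parser yields with and without them:

```python
import argparse

parser = argparse.ArgumentParser()
# Same flag / dest / action triples as three of the flags above.
parser.add_argument("--no-pin-memory", dest="data.pin_memory", action="store_false")
parser.add_argument("--drop-last", dest="data.drop_last", action="store_true")
parser.add_argument("--no-shuffle-eval", dest="data.shuffle_eval", action="store_false")

# store_false defaults to True, store_true defaults to False.
defaults = vars(parser.parse_args([]))
print(defaults)
# {'data.pin_memory': True, 'data.drop_last': False, 'data.shuffle_eval': True}

flipped = vars(parser.parse_args(["--no-pin-memory", "--drop-last", "--no-shuffle-eval"]))
print(flipped)
# {'data.pin_memory': False, 'data.drop_last': True, 'data.shuffle_eval': False}
```

With no flags the parser yields data.shuffle_eval=True, while the constructor default for shuffle_eval is False. If the parsed value is forwarded to the constructor, command-line runs shuffle the evaluation loaders unless --no-shuffle-eval is passed.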

prepare_data(self)

Use this to download and prepare data.

Warning: Do not set state on the model here (use setup instead), since prepare_data is NOT called on every GPU in DDP/TPU.

Example:

def prepare_data(self):
    # good
    download_data()
    tokenize()
    etc()

    # bad
    self.split = data_split
    self.some_state = some_other_state()

In DDP, prepare_data can be called in two ways, selected with Trainer(prepare_data_per_node):

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.
  2. Once in total. Only called on GLOBAL_RANK=0.

Example:

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
Trainer(prepare_data_per_node=True)

# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)

This is called before requesting the dataloaders:

model.prepare_data()
if ddp/tpu: init()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
Source code in slp/plbind/dm.py
def prepare_data(self):
    return None
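The warning about state is easier to see with a toy simulation. This is a sketch, not Lightning's Trainer: ToyDataModule and simulate_fit are hypothetical, and the only rule encoded is the one stated above, that with prepare_data_per_node=True prepare_data runs only on LOCAL_RANK=0 while setup runs on every process. (PLDataModuleFromDatasets.prepare_data itself simply returns None.)

```python
from typing import List


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule that records hook calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        self.calls.append("prepare_data")
        self.vocab = {"<pad>": 0}  # BAD: only exists on processes that ran prepare_data

    def setup(self, stage: str) -> None:
        self.calls.append(f"setup({stage})")
        self.stage = stage  # GOOD: setup runs on every process

    def train_dataloader(self) -> None:
        self.calls.append("train_dataloader")

    def val_dataloader(self) -> None:
        self.calls.append("val_dataloader")


def simulate_fit(dm: ToyDataModule, local_rank: int) -> ToyDataModule:
    """Simplified hook order for one process with prepare_data_per_node=True."""
    if local_rank == 0:
        dm.prepare_data()
    dm.setup("fit")
    dm.train_dataloader()
    dm.val_dataloader()
    return dm


rank0 = simulate_fit(ToyDataModule(), local_rank=0)
rank1 = simulate_fit(ToyDataModule(), local_rank=1)
print(rank0.calls)  # ['prepare_data', 'setup(fit)', 'train_dataloader', 'val_dataloader']
print(hasattr(rank1, "vocab"), hasattr(rank1, "stage"))  # False True
```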

test_dataloader(self)

Configure test DataLoader

Returns:

Type | Description
DataLoader | PyTorch DataLoader for test set

Source code in slp/plbind/dm.py
def test_dataloader(self):
    """Configure test DataLoader

    Returns:
        DataLoader: Pytorch DataLoader for test set
    """

    return DataLoader(
        self.test,
        batch_size=self.batch_size_eval if self.batch_sampler_test is None else 1,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=self.drop_last and (self.batch_sampler_test is None),
        sampler=self.sampler_test,
        batch_sampler=self.batch_sampler_test,
        shuffle=(
            self.shuffle_eval
            and (self.batch_sampler_test is None)
            and (self.sampler_test is None)
        ),
        collate_fn=self.collate_fn,
    )

train_dataloader(self)

Configure train DataLoader

Returns:

Type | Description
DataLoader | PyTorch DataLoader for train set

Source code in slp/plbind/dm.py
def train_dataloader(self) -> DataLoader:
    """Configure train DataLoader

    Returns:
        DataLoader: Pytorch DataLoader for train set
    """

    return DataLoader(
        self.train,
        batch_size=self.batch_size if self.batch_sampler_train is None else 1,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=self.drop_last and (self.batch_sampler_train is None),
        sampler=self.sampler_train,
        batch_sampler=self.batch_sampler_train,
        shuffle=(self.batch_sampler_train is None) and (self.sampler_train is None),
        collate_fn=self.collate_fn,
    )

val_dataloader(self)

Configure validation DataLoader

Returns:

Type | Description
DataLoader | PyTorch DataLoader for validation set

Source code in slp/plbind/dm.py
def val_dataloader(self):
    """Configure validation DataLoader

    Returns:
        DataLoader: Pytorch DataLoader for validation set
    """
    val = DataLoader(
        self.val,
        batch_size=self.batch_size_eval if self.batch_sampler_val is None else 1,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=self.drop_last and (self.batch_sampler_val is None),
        sampler=self.sampler_val,
        batch_sampler=self.batch_sampler_val,
        shuffle=(
            self.shuffle_eval
            and (self.batch_sampler_val is None)
            and (self.sampler_val is None)
        ),
        collate_fn=self.collate_fn,
    )

    return val
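train_dataloader, val_dataloader and test_dataloader share one rule: when a batch sampler is supplied, batch_size is forced to 1 and drop_last and shuffle are turned off; shuffle is also off when a sampler is supplied. This is what torch.utils.data.DataLoader requires, since it rejects a batch_sampler combined with batch_size != 1, shuffle, sampler or drop_last, and rejects shuffle=True combined with a sampler. A stdlib-only sketch of that resolution; loader_kwargs is a hypothetical helper and plain strings stand in for sampler objects.

```python
from typing import Any, Dict, Optional


def loader_kwargs(
    batch_size: int,
    drop_last: bool,
    shuffle: bool,
    sampler: Optional[Any] = None,
    batch_sampler: Optional[Any] = None,
) -> Dict[str, Any]:
    """Resolve DataLoader arguments the way the three dataloader methods above do."""
    uses_batch_sampler = batch_sampler is not None
    return {
        # A batch sampler yields whole batches, so batch_size must stay at 1.
        "batch_size": 1 if uses_batch_sampler else batch_size,
        "drop_last": drop_last and not uses_batch_sampler,
        # Any sampler takes over the ordering, so shuffle must be False.
        "shuffle": shuffle and not uses_batch_sampler and sampler is None,
        "sampler": sampler,
        "batch_sampler": batch_sampler,
    }


# Train loader: shuffle=True unless a (batch) sampler takes over ordering.
train_plain = loader_kwargs(batch_size=32, drop_last=True, shuffle=True)
# "bucket-sampler" is a placeholder for a real BatchSampler instance.
train_bucketed = loader_kwargs(32, True, True, batch_sampler="bucket-sampler")
# Eval loaders pass shuffle_eval instead of True.
val_sampled = loader_kwargs(64, False, True, sampler="ordered-sampler")
print(train_plain)
print(train_bucketed)
print(val_sampled)
```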

split_data(dataset, test_size, seed)

Train-test split of dataset.

The dataset can be either a torch.utils.data.Dataset or a list. A Dataset is split with torch.utils.data.random_split; any other input is split with train_test_split.

Parameters:

Name | Type | Description | Default
dataset | Union[Dataset, List] | Input dataset | required
test_size | float | Fraction of the dataset to use for the test set, e.g. 0.2. | required
seed | Optional[int] | Seed for a deterministic split. If None, the split is random. | required

Returns:

Type | Description
Tuple[Union[Dataset, List], Union[Dataset, List]] | (train set, test set)

Source code in slp/plbind/dm.py
def split_data(dataset, test_size, seed):
    """Train-test split of dataset.

    Dataset can be either a torch.utils.data.Dataset or a list

    Args:
        dataset (Union[Dataset, List]): Input dataset
        test_size (float): Fraction of the dataset to use for the test set, e.g. 0.2.
        seed (Optional[int]): Seed for a deterministic split. If None, the split is random.

    Returns:
        Tuple[Union[Dataset, List], Union[Dataset, List]]: (train set, test set)
    """
    train, test = None, None

    if isinstance(dataset, torch.utils.data.Dataset):
        test_len = int(test_size * len(dataset))
        train_len = len(dataset) - test_len

        seed_generator = None

        if seed is not None:
            seed_generator = torch.Generator().manual_seed(seed)

        train, test = random_split(
            dataset, [train_len, test_len], generator=seed_generator
        )

    else:

        train, test = train_test_split(dataset, test_size=test_size, random_state=seed)

    return train, test
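For a torch Dataset, split_data sizes the test set as int(test_size * len(dataset)), so the fraction is truncated and the remainder goes to train. A minimal stdlib sketch of that sizing rule; split_list is hypothetical and shuffles with random.Random instead of torch.utils.data.random_split, so with the same seed it will not reproduce torch's permutation.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_list(
    items: Sequence[T], test_size: float, seed: Optional[int]
) -> Tuple[List[T], List[T]]:
    """Hypothetical stdlib analogue of the Dataset branch of split_data."""
    test_len = int(test_size * len(items))  # int() truncates, as in split_data
    train_len = len(items) - test_len  # the remainder goes to train
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # seed=None -> non-deterministic split
    train = [items[i] for i in indices[:train_len]]
    test = [items[i] for i in indices[train_len:]]
    return train, test


train, test = split_list(list(range(10)), test_size=0.25, seed=42)
print(len(train), len(test))  # 8 2
```

Assuming the list branch uses scikit-learn's train_test_split, note that it rounds a fractional test size up rather than down: with 10 items and test_size=0.25 it returns 3 test items instead of 2, so the two branches can differ by one example.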

FixedWandbLogger

__init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False, version=None, project=None, log_model=False, experiment=None, prefix='', sync_step=True, checkpoint_dir=None, **kwargs) special

WandbLogger fix for saving checkpoints in wandb

Accepts an additional checkpoint_dir argument, pointing to the real checkpoint directory.

Parameters:

Name | Type | Description | Default
name | Optional[str] | Display name for the run. Defaults to None. | None
save_dir | Optional[str] | Path where data is saved. Defaults to None. | None
offline | Optional[bool] | Run offline (data can be streamed later to wandb servers). Defaults to False. | False
id | Optional[str] | Sets the version, mainly used to resume a previous run. Defaults to None. | None
anonymous | Optional[bool] | Enables or explicitly disables anonymous logging. Defaults to False. | False
version | Optional[str] | Sets the version, mainly used to resume a previous run. Defaults to None. | None
project | Optional[str] | The name of the project to which this run will belong. Defaults to None. | None
log_model | Optional[bool] | Save checkpoints in wandb dir to upload on W&B servers. Defaults to False. | False
experiment | Run | WandB experiment object. Defaults to None. | None
prefix | Optional[str] | A string to put at the beginning of metric keys. Defaults to "". | ''
sync_step | Optional[bool] | Sync Trainer step with wandb step. Defaults to True. | True
checkpoint_dir | Optional[str] | Real checkpoint dir. If None, save_dir is used when uploading checkpoints. Defaults to None. | None
Source code in slp/plbind/helpers.py
def __init__(
    self,
    name: Optional[str] = None,
    save_dir: Optional[str] = None,
    offline: Optional[bool] = False,
    id: Optional[str] = None,
    anonymous: Optional[bool] = False,
    version: Optional[str] = None,
    project: Optional[str] = None,
    log_model: Optional[bool] = False,
    experiment: wandb.sdk.wandb_run.Run = None,
    prefix: Optional[str] = "",
    sync_step: Optional[bool] = True,
    checkpoint_dir: Optional[str] = None,
    **kwargs,
):
    """Wandb logger fix to save checkpoints in wandb

    Accepts an additional checkpoint_dir argument, pointing to the real checkpoint directory

    Args:
        name (Optional[str]): Display name for the run. Defaults to None.
        save_dir (Optional[str]): Path where data is saved. Defaults to None.
        offline (Optional[bool]): Run offline (data can be streamed later to wandb servers). Defaults to False.
        id (Optional[str]): Sets the version, mainly used to resume a previous run. Defaults to None.
        anonymous (Optional[bool]): Enables or explicitly disables anonymous logging. Defaults to False.
        version (Optional[str]): Sets the version, mainly used to resume a previous run. Defaults to None.
        project (Optional[str]): The name of the project to which this run will belong. Defaults to None.
        log_model (Optional[bool]): Save checkpoints in wandb dir to upload on W&B servers. Defaults to False.
        experiment (wandb.sdk.wandb_run.Run): WandB experiment object. Defaults to None.
        prefix (Optional[str]): A string to put at the beginning of metric keys. Defaults to "".
        sync_step (Optional[bool]): Sync Trainer step with wandb step. Defaults to True.
        checkpoint_dir (Optional[str]): Real checkpoint dir. Defaults to None.
    """
    self._checkpoint_dir = checkpoint_dir
    super(FixedWandbLogger, self).__init__(
        name=name,
        save_dir=save_dir,
        offline=offline,
        id=id,
        anonymous=anonymous,
        version=version,
        project=project,
        log_model=log_model,
        experiment=experiment,
        prefix=prefix,
        sync_step=sync_step,
        **kwargs,
    )

finalize(self, status)

Determine where checkpoints are saved and upload to wandb servers

Parameters:

Name | Type | Description | Default
status | str | Experiment status | required
Source code in slp/plbind/helpers.py
@rank_zero_only
def finalize(self, status: str) -> None:
    """Determine where checkpoints are saved and upload to wandb servers

    Args:
        status (str): Experiment status
    """
    # offset future training logged on same W&B run

    if self._experiment is not None:
        self._step_offset = self._experiment.step

    checkpoint_dir = (
        self._checkpoint_dir if self._checkpoint_dir is not None else self.save_dir
    )

    if checkpoint_dir is None:
        logger.warning(
            "Invalid checkpoint dir. Checkpoints will not be uploaded to Wandb."
        )
        logger.info(
            "You can manually upload your checkpoints through the CLI interface."
        )

    else:
        # upload all checkpoints from saving dir

        if self._log_model:
            wandb.save(os.path.join(checkpoint_dir, "*.ckpt"))
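finalize resolves the upload directory (checkpoint_dir first, then save_dir) and, when log_model is set, hands a *.ckpt glob pattern to wandb.save. A sketch of the directory resolution and pattern matching only, without contacting W&B; checkpoints_to_upload is a hypothetical helper and does not model the log_model check.

```python
import glob
import os
import tempfile
from typing import List, Optional


def checkpoints_to_upload(
    checkpoint_dir: Optional[str], save_dir: Optional[str]
) -> Optional[List[str]]:
    """Same fallback as finalize(): checkpoint_dir first, then save_dir."""
    target = checkpoint_dir if checkpoint_dir is not None else save_dir
    if target is None:
        return None  # finalize() only logs a warning in this case
    # finalize() passes this pattern to wandb.save; here we expand it locally.
    return sorted(glob.glob(os.path.join(target, "*.ckpt")))


with tempfile.TemporaryDirectory() as tmp:
    for name in ("last.ckpt", "best.ckpt", "config.yaml"):
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in checkpoints_to_upload(tmp, None)]
    fallback = [os.path.basename(p) for p in checkpoints_to_upload(None, tmp)]

print(found)  # ['best.ckpt', 'last.ckpt']
print(checkpoints_to_upload(None, None))  # None
```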

FromLogits

__init__(self, metric) special

Wrap a PyTorch Lightning metric so that it accepts logits as input

Parameters:

Name | Type | Description | Default
metric | Metric | The metric to wrap, e.g. pl.metrics.Accuracy | required
Source code in slp/plbind/helpers.py
def __init__(self, metric: pl.metrics.Metric):
    """Wrap a PyTorch Lightning metric so that it accepts logits as input

    Args:
        metric (pl.metrics.Metric): The metric to wrap, e.g. pl.metrics.Accuracy
    """
    super(FromLogits, self).__init__(
        compute_on_step=metric.compute_on_step,
        dist_sync_on_step=metric.dist_sync_on_step,
        process_group=metric.process_group,
        dist_sync_fn=metric.dist_sync_fn,
    )
    self.metric = metric

compute(self)

Compute metric

Returns:

Type | Description
Tensor | Metric value

Source code in slp/plbind/helpers.py
def compute(self) -> torch.Tensor:
    """Compute metric

    Returns:
        torch.Tensor: metric value
    """
    return self.metric.compute()  # type: ignore

update(self, preds, target)

Update underlying metric

Apply softmax over the last dimension of the logits and pass the resulting probabilities to the underlying metric.

Parameters:

Name | Type | Description | Default
preds | Tensor | [B, *, num_classes] Logits | required
target | Tensor | [B, *] Ground truths | required
Source code in slp/plbind/helpers.py
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:  # type: ignore
    """Update underlying metric

    Calculate softmax under the hood and pass probs to the underlying metric

    Args:
        preds (torch.Tensor): [B, *, num_classes] Logits
        target (torch.Tensor): [B, *] Ground truths
    """
    preds = F.softmax(preds, dim=-1)
    self.metric.update(preds, target)  # type: ignore
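The wrapper's only job is to turn logits into probabilities before delegating to the wrapped metric. A minimal pure-Python sketch of that pattern, assuming a hypothetical `ToyAccuracy` metric with `update`/`compute` methods (not the pytorch-lightning `Metric` API) and a hand-written softmax standing in for `F.softmax(..., dim=-1)`:

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits (the last dimension)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyAccuracy:
    """Hypothetical metric: consumes per-class probabilities, tracks argmax accuracy."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, probs: List[List[float]], target: List[int]) -> None:
        for row, t in zip(probs, target):
            pred = max(range(len(row)), key=row.__getitem__)
            self.correct += int(pred == t)
            self.total += 1

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


class FromLogitsSketch:
    """Mirrors FromLogits: softmax over the last dimension, then delegate."""

    def __init__(self, metric: ToyAccuracy) -> None:
        self.metric = metric

    def update(self, logits: List[List[float]], target: List[int]) -> None:
        probs = [softmax(row) for row in logits]
        self.metric.update(probs, target)

    def compute(self) -> float:
        return self.metric.compute()


acc = FromLogitsSketch(ToyAccuracy())
acc.update([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 1])
print(acc.compute())  # 0.5
```

Since softmax is monotonic, argmax-based metrics give the same result on logits and probabilities; the wrapper matters for metrics that interpret their inputs as probabilities (for example, ones that apply a threshold).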

AutoEncoderPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _AutoEncoder

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(AutoEncoderPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_AutoEncoder,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

BertPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _BertSequenceClassification

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(BertPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_BertSequenceClassification,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

MultimodalTransformerClassificationPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _MultimodalTransformerClassification

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(MultimodalTransformerClassificationPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_MultimodalTransformerClassification,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

PLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _Classification

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(PLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_Classification,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

RnnPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _RnnClassification

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(RnnPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_RnnClassification,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

SimplePLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, predictor_cls=<class 'slp.plbind.module._Classification'>, calculate_perplexity=False) special

Wraps a (model, optimizer, criterion, lr_scheduler) tuple in a LightningModule

Handles the boilerplate for metrics calculation and logging, and defines training_step / validation_step / test_step with the help of the predictor helper classes (e.g. _Classification, _RnnClassification)

Parameters:

Name | Type | Description | Default
model | Module | Module to use for prediction | required
optimizer | Union[torch.optim.optimizer.Optimizer, List[torch.optim.optimizer.Optimizer]] | Optimizers to use for training | required
criterion | Union[torch.nn.modules.module.Module, Callable] | Task loss | required
lr_scheduler | Union[torch.optim.lr_scheduler._LRScheduler, List[torch.optim.lr_scheduler._LRScheduler]] | Learning rate scheduler. Defaults to None. | None
hparams | Union[omegaconf.dictconfig.DictConfig, Dict[str, Any], argparse.Namespace] | Hyperparameter values. Passing them ensures they are logged with trainer.loggers. Defaults to None. | None
metrics | Optional[Dict[str, pytorch_lightning.metrics.metric.Metric]] | Metrics to track. Defaults to None. | None
predictor_cls | type | Class that defines a parse_batch and a get_predictions_and_targets method. Defaults to _Classification. | <class 'slp.plbind.module._Classification'>
calculate_perplexity | bool | Whether to calculate perplexity. This would be cleaner as a metric, but computing it here is more efficient. Defaults to False. | False
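A predictor_cls only has to provide parse_batch and get_predictions_and_targets, and SimplePLModule instantiates it with no arguments (predictor_cls()). A minimal sketch of a hypothetical predictor for batches shaped as (inputs, targets); the class name and the toy model are illustrative assumptions, not slp's _Classification:

```python
from typing import Any, Callable, List, Sequence, Tuple


class PairPredictor:
    """Hypothetical predictor for batches shaped as (inputs, targets)."""

    def parse_batch(self, batch: Tuple[Any, Any]) -> Tuple[Any, Any]:
        inputs, targets = batch
        return inputs, targets

    def get_predictions_and_targets(
        self, model: Callable[[Any], Any], batch: Tuple[Any, Any]
    ) -> Tuple[Any, Any]:
        # Unpack the batch, run the model on the inputs, return (y_hat, targets)
        inputs, targets = self.parse_batch(batch)
        return model(inputs), targets


def double_model(xs: Sequence[float]) -> List[float]:
    """Stand-in for an nn.Module: any callable works for this sketch."""
    return [2 * x for x in xs]


y_hat, targets = PairPredictor().get_predictions_and_targets(
    double_model, ([1.0, 2.0], [2.0, 4.0])
)
print(y_hat, targets)  # [2.0, 4.0] [2.0, 4.0]
```

The returned pair is what the step methods feed to the criterion and to the metrics, so the predictor is the single place where batch layout is encoded.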
Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    predictor_cls=_Classification,
    calculate_perplexity: bool = False,  # for LM. Dirty but much more efficient
):
    """Wraps a (model, optimizer, criterion, lr_scheduler) tuple in a LightningModule

    Handles the boilerplate for metrics calculation and logging, and defines training_step / validation_step / test_step
    with the help of the predictor helper classes (e.g. _Classification, _RnnClassification)

    Args:
        model (nn.Module): Module to use for prediction
        optimizer (Union[Optimizer, List[Optimizer]]): Optimizers to use for training
        criterion (LossType): Task loss
        lr_scheduler (Union[_LRScheduler, List[_LRScheduler]], optional): Learning rate scheduler. Defaults to None.
        hparams (Configuration, optional): Hyperparameter values. This ensures they are logged with trainer.loggers. Defaults to None.
        metrics (Optional[Dict[str, pl.metrics.Metric]], optional): Metrics to track. Defaults to None.
        predictor_cls ([type], optional): Class that defines a parse_batch and a
                get_predictions_and_targets method. Defaults to _Classification.
        calculate_perplexity (bool, optional): Whether to calculate perplexity.
                Would be cleaner as a metric, but this is more efficient. Defaults to False.
    """
    super(SimplePLModule, self).__init__()
    self.calculate_perplexity = calculate_perplexity
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.criterion = criterion

    if metrics is not None:
        self.train_metrics = nn.ModuleDict(metrics)
        self.val_metrics = nn.ModuleDict({k: v.clone() for k, v in metrics.items()})
        self.test_metrics = nn.ModuleDict(
            {k: v.clone() for k, v in metrics.items()}
        )
    else:
        self.train_metrics = nn.ModuleDict(modules=None)
        self.val_metrics = nn.ModuleDict(modules=None)
        self.test_metrics = nn.ModuleDict(modules=None)
    self.predictor = predictor_cls()

    if hparams is not None:
        if isinstance(hparams, Namespace):
            dict_params = vars(hparams)
        elif isinstance(hparams, DictConfig):
            dict_params = cast(Dict[str, Any], OmegaConf.to_container(hparams))
        else:
            dict_params = hparams
        # self.hparams = dict_params
        self.save_hyperparameters(dict_params)
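Training, validation and test metrics are kept as separate objects so that state accumulated in one stage never leaks into another; the caller's instances are reused for training and clones are made for the other two. A sketch of the same idea, with copy.deepcopy standing in for Metric.clone() and a hypothetical RunningMean metric:

```python
import copy
from typing import Dict, Optional, Tuple


class RunningMean:
    """Hypothetical stateful metric that accumulates values across updates."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0


Metrics = Dict[str, RunningMean]


def split_metrics(metrics: Optional[Metrics]) -> Tuple[Metrics, Metrics, Metrics]:
    """Return independent train/val/test metric dicts."""
    if metrics is None:
        return {}, {}, {}
    train = metrics  # the caller's instances are used for training
    val = {k: copy.deepcopy(v) for k, v in metrics.items()}
    test = {k: copy.deepcopy(v) for k, v in metrics.items()}
    return train, val, test


train_m, val_m, test_m = split_metrics({"mean": RunningMean()})
train_m["mean"].update(1.0)
val_m["mean"].update(3.0)
print(train_m["mean"].compute(), val_m["mean"].compute(), test_m["mean"].compute())
# 1.0 3.0 0.0
```

Because the copies are taken at construction time, metrics should be passed in fresh; any state they already hold would be duplicated into all three stages.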

aggregate_epoch_metrics(self, outputs, mode='Training')

Aggregate metrics over a whole epoch

Parameters:

Name | Type | Description | Default
outputs | List[Dict[str, torch.Tensor]] | Aggregated outputs from training_step, validation_step or test_step | required
mode | str | "Training", "Validation" or "Test". Defaults to "Training". | 'Training'
Source code in slp/plbind/module.py
def aggregate_epoch_metrics(self, outputs, mode="Training"):
    """Aggregate metrics over a whole epoch

    Args:
        outputs (List[Dict[str, torch.Tensor]]): Aggregated outputs from training_step, validation_step or test_step
        mode (str, optional): "Training", "Validation" or "Test". Defaults to "Training".
    """

    def fmt(name):
        """Format metric name"""

        return f"{name}" if name != "loss" else "train_loss"

    keys = list(outputs[0].keys())
    aggregated = {fmt(k): torch.stack([x[k] for x in outputs]).mean() for k in keys}
    aggregated["epoch"] = self.current_epoch + 1

    self.log_dict(aggregated, logger=True, prog_bar=False, on_epoch=True)

    return aggregated
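aggregate_epoch_metrics averages each key over the per-step output dicts, renames loss to train_loss and records a 1-based epoch. A pure-Python sketch of that reduction, with floats standing in for tensors and statistics.mean for torch.stack(...).mean(); the "accuracy" key is only an illustrative name:

```python
from statistics import mean
from typing import Dict, List


def aggregate(outputs: List[Dict[str, float]], current_epoch: int) -> Dict[str, float]:
    """Average every metric key over the epoch's step outputs."""

    def fmt(name: str) -> str:
        # training_step stores the raw loss under "loss"; report it as "train_loss"
        return name if name != "loss" else "train_loss"

    keys = list(outputs[0].keys())
    aggregated = {fmt(k): mean(x[k] for x in outputs) for k in keys}
    aggregated["epoch"] = current_epoch + 1  # current_epoch is 0-based
    return aggregated


steps = [{"loss": 1.0, "accuracy": 0.5}, {"loss": 0.5, "accuracy": 0.75}]
print(aggregate(steps, current_epoch=0))
# {'train_loss': 0.75, 'accuracy': 0.625, 'epoch': 1}
```

The keys are read from the first step only, so every step must return the same set of keys, and an empty outputs list would fail on outputs[0].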

configure_optimizers(self)

Return optimizers and learning rate schedulers

Returns:

Type | Description
Tuple[List[Optimizer], List[_LRScheduler]] | (optimizers, lr_schedulers) when a learning rate scheduler is set; otherwise only the optimizer is returned

Source code in slp/plbind/module.py
def configure_optimizers(self):
    """Return optimizers and learning rate schedulers

    Returns:
        Tuple[List[Optimizer], List[_LRScheduler]]: (optimizers, lr_schedulers)
    """

    if self.lr_scheduler is not None:
        scheduler = {
            "scheduler": self.lr_scheduler,
            "interval": "epoch",
            "monitor": "val_loss",
        }

        return [self.optimizer], [scheduler]

    return self.optimizer

forward(self, *args, **kwargs)

Call wrapped module forward

Source code in slp/plbind/module.py
def forward(self, *args, **kwargs):
    """Call wrapped module forward"""

    return self.model(*args, **kwargs)

log_to_console(self, metrics, mode='Training')

Log metrics to console

Parameters:

Name | Type | Description | Default
metrics | Dict[str, torch.Tensor] | Computed metrics | required
mode | str | "Training", "Validation" or "Test". Defaults to "Training". | 'Training'
Source code in slp/plbind/module.py
def log_to_console(self, metrics, mode="Training"):
    """Log metrics to console

    Args:
        metrics (Dict[str, torch.Tensor]): Computed metrics
        mode (str, optional): "Training", "Validation" or "Test". Defaults to "Training".
    """
    logger.info("Epoch {} {} results".format(self.current_epoch + 1, mode))
    print_separator(symbol="-", n=50, print_fn=logger.info)

    for name, value in metrics.items():
        if name == "epoch":
            continue
        logger.info("{:<15} {:<15}".format(name, value))

    print_separator(symbol="%", n=50, print_fn=logger.info)
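Each metric is written as a left-aligned name/value pair and the epoch entry is skipped because it already appears in the header. A sketch that collects the lines in a list instead of logging them; the dash and percent rules are an assumption standing in for slp's print_separator, whose exact output is not shown here:

```python
from typing import Dict, List


def format_metrics(metrics: Dict[str, float], epoch: int, mode: str = "Training") -> List[str]:
    """Build the console lines for one epoch's metrics."""
    lines = ["Epoch {} {} results".format(epoch, mode), "-" * 50]
    for name, value in metrics.items():
        if name == "epoch":
            continue  # already shown in the header line
        lines.append("{:<15} {:<15}".format(name, value))
    lines.append("%" * 50)
    return lines


for line in format_metrics({"val_loss": 0.25, "epoch": 3}, epoch=3, mode="Validation"):
    print(line)
```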

test_epoch_end(self, outputs)

Aggregate metrics of a test epoch

Parameters:

Name | Type | Description | Default
outputs | List[Dict[str, torch.Tensor]] | Aggregated outputs from test_step | required
Source code in slp/plbind/module.py
def test_epoch_end(self, outputs):
    """Aggregate metrics of a test epoch

    Args:
        outputs (List[Dict[str, torch.Tensor]]): Aggregated outputs from test_step
    """
    outputs = self.aggregate_epoch_metrics(outputs, mode="Test")
    self.log_to_console(outputs, mode="Test")

test_step(self, batch, batch_idx)

Compute loss for a single test step and log metrics to loggers

Parameters:

Name | Type | Description | Default
batch | Tuple[torch.Tensor, ...] | Input batch | required
batch_idx | int | Index of batch | required

Returns:

Type | Description
Dict[str, torch.Tensor] | Computed metrics

Source code in slp/plbind/module.py
def test_step(self, batch, batch_idx):
    """Compute loss for a single test step and log metrics to loggers

    Args:
        batch (Tuple[torch.Tensor, ...]): Input batch
        batch_idx (int): Index of batch

    Returns:
        Dict[str, torch.Tensor]: computed metrics
    """
    y_hat, targets = self.predictor.get_predictions_and_targets(self, batch)
    loss = self.criterion(y_hat, targets)
    metrics = self._compute_metrics(
        self.test_metrics, loss, y_hat, targets, mode="test"
    )

    return metrics

training_epoch_end(self, outputs)

Aggregate metrics of a training epoch

Parameters:

Name | Type | Description | Default
outputs | List[Dict[str, torch.Tensor]] | Aggregated outputs from training_step | required
Source code in slp/plbind/module.py
def training_epoch_end(self, outputs):
    """Aggregate metrics of a training epoch

    Args:
        outputs (List[Dict[str, torch.Tensor]]): Aggregated outputs from training_step
    """
    outputs = self.aggregate_epoch_metrics(outputs, mode="Training")
    self.log_to_console(outputs, mode="Training")

training_step(self, batch, batch_idx)

Compute loss for a single training step and log metrics to loggers

Parameters:

Name | Type | Description | Default
batch | Tuple[torch.Tensor, ...] | Input batch | required
batch_idx | int | Index of batch | required

Returns:

Type | Description
Dict[str, torch.Tensor] | Computed metrics

Source code in slp/plbind/module.py
def training_step(self, batch, batch_idx):
    """Compute loss for a single training step and log metrics to loggers

    Args:
        batch (Tuple[torch.Tensor, ...]): Input batch
        batch_idx (int): Index of batch

    Returns:
        Dict[str, torch.Tensor]: computed metrics
    """
    y_hat, targets = self.predictor.get_predictions_and_targets(self.model, batch)
    loss = self.criterion(y_hat, targets)
    metrics = self._compute_metrics(
        self.train_metrics, loss, y_hat, targets, mode="train"
    )

    self.log_dict(
        metrics,
        on_step=True,
        on_epoch=False,
        logger=True,
        prog_bar=False,
    )

    metrics["loss"] = loss

    return metrics

validation_epoch_end(self, outputs)

Aggregate metrics of a validation epoch

Parameters:

Name | Type | Description | Default
outputs | List[Dict[str, torch.Tensor]] | Aggregated outputs from validation_step | required
Source code in slp/plbind/module.py
def validation_epoch_end(self, outputs):
    """Aggregate metrics of a validation epoch

    Args:
        outputs (List[Dict[str, torch.Tensor]]): Aggregated outputs from validation_step
    """
    outputs = self.aggregate_epoch_metrics(outputs, mode="Validation")

    if torch.isnan(outputs["val_loss"]) or torch.isinf(outputs["val_loss"]):
        # Keep a tensor so that .detach() below still works when val_loss is monitored
        outputs["val_loss"] = torch.tensor(1000000.0)

    outputs["best_score"] = min(
        outputs[self.trainer.early_stopping_callback.monitor].detach().cpu(),
        self.trainer.early_stopping_callback.best_score.detach().cpu(),
    )
    self.log_to_console(outputs, mode="Validation")
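Two details matter in validation_epoch_end: a NaN or infinite validation loss is replaced by a large finite sentinel before the best_score comparison, and best_score is taken with min, which is only correct for a metric that is minimized. A pure-Python sketch of a mode-aware variant; the mode parameter is an assumption mirroring --early-stop-mode and is not read by the code above:

```python
import math


def guard_loss(val_loss: float, sentinel: float = 1e6) -> float:
    """Replace a NaN or infinite loss by a large finite sentinel."""
    return sentinel if math.isnan(val_loss) or math.isinf(val_loss) else val_loss


def best_score(current: float, previous_best: float, mode: str = "min") -> float:
    """Running best value of a monitored metric (assumed mode-aware variant)."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    return min(current, previous_best) if mode == "min" else max(current, previous_best)


print(guard_loss(float("nan")))  # 1000000.0
print(best_score(0.8, 0.9, mode="max"))  # 0.9
print(best_score(guard_loss(0.3), float("inf")))  # 0.3
```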

validation_step(self, batch, batch_idx)

Compute loss for a single validation step and log metrics to loggers

Parameters:

Name | Type | Description | Default
batch | Tuple[torch.Tensor, ...] | Input batch | required
batch_idx | int | Index of batch | required

Returns:

Type | Description
Dict[str, torch.Tensor] | Computed metrics

Source code in slp/plbind/module.py
def validation_step(self, batch, batch_idx):
    """Compute loss for a single validation step and log metrics to loggers

    Args:
        batch (Tuple[torch.Tensor, ...]): Input batch
        batch_idx (int): Index of batch

    Returns:
        Dict[str, torch.Tensor]: computed metrics
    """
    y_hat, targets = self.predictor.get_predictions_and_targets(self, batch)
    loss = self.criterion(y_hat, targets)
    metrics = self._compute_metrics(
        self.val_metrics, loss, y_hat, targets, mode="val"
    )

    metrics[
        "best_score"
    ] = self.trainer.early_stopping_callback.best_score.detach().cpu()

    return metrics

TransformerClassificationPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _TransformerClassification

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(TransformerClassificationPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_TransformerClassification,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

TransformerPLModule

__init__(self, model, optimizer, criterion, lr_scheduler=None, hparams=None, metrics=None, calculate_perplexity=False) special

Pass arguments through to the base class, with predictor_cls set to _Transformer

Source code in slp/plbind/module.py
def __init__(
    self,
    model: nn.Module,
    optimizer: Union[Optimizer, List[Optimizer]],
    criterion: LossType,
    lr_scheduler: Union[_LRScheduler, List[_LRScheduler]] = None,
    hparams: Configuration = None,
    metrics: Optional[Dict[str, pl.metrics.Metric]] = None,
    calculate_perplexity=False,
):
    """Pass arguments through to base class"""
    super(TransformerPLModule, self).__init__(
        model,
        optimizer,
        criterion,
        predictor_cls=_Transformer,
        lr_scheduler=lr_scheduler,
        hparams=hparams,
        metrics=metrics,
        calculate_perplexity=calculate_perplexity,
    )

add_optimizer_args(parent_parser)

Augment parser with optimizer arguments

Parameters:

Name | Type | Description | Default
parent_parser | ArgumentParser | Parser created by the user | required

Returns:

Type | Description
ArgumentParser | argparse.ArgumentParser: Augmented parser

Source code in slp/plbind/trainer.py
def add_optimizer_args(
    parent_parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Augment parser with optimizer arguments

    Args:
        parent_parser (argparse.ArgumentParser): Parser created by the user

    Returns:
        argparse.ArgumentParser: Augmented parser
    """
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument(
        "--optimizer",
        dest="optimizer",
        type=str,
        choices=[
            "Adam",
            "AdamW",
            "SGD",
            "Adadelta",
            "Adagrad",
            "Adamax",
            "ASGD",
            "RMSprop",
        ],
        default="Adam",
        help="Which optimizer to use",
    )

    parser.add_argument(
        "--lr",
        dest="optim.lr",
        type=float,
        default=1e-3,
        help="Learning rate",
    )

    parser.add_argument(
        "--weight-decay",
        dest="optim.weight_decay",
        type=float,
        default=0,
        help="Weight decay",
    )

    parser.add_argument(
        "--lr-scheduler",
        dest="lr_scheduler",
        action="store_true",
        # type=str,
        # choices=["ReduceLROnPlateau"],
        help="Use learning rate scheduling. Currently only ReduceLROnPlateau is supported out of the box",
    )

    parser.add_argument(
        "--lr-factor",
        dest="lr_schedule.factor",
        type=float,
        default=0.1,
        help="Multiplicative factor by which LR is reduced. Used if --lr-scheduler is provided.",
    )

    parser.add_argument(
        "--lr-patience",
        dest="lr_schedule.patience",
        type=int,
        default=10,
        help="Number of epochs with no improvement after which learning rate will be reduced. Used if --lr-scheduler is provided.",
    )

    parser.add_argument(
        "--lr-cooldown",
        dest="lr_schedule.cooldown",
        type=int,
        default=0,
        help="Number of epochs to wait before resuming normal operation after lr has been reduced. Used if --lr-scheduler is provided.",
    )

    parser.add_argument(
        "--min-lr",
        dest="lr_schedule.min_lr",
        type=float,
        default=0,
        help="Minimum lr for LR scheduling. Used if --lr-scheduler is provided.",
    )

    return parser
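Because several destinations contain dots (optim.lr, lr_schedule.factor), argparse stores them as flat attribute names that have to be read with getattr or vars rather than attribute syntax. A sketch that rebuilds a subset of these flags with the same parents pattern and folds the dotted keys into nested dicts; nest_dotted is a hypothetical helper, and how slp itself consumes these keys is not shown in this section:

```python
import argparse
from typing import Any, Dict


def nest_dotted(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Turn {"optim.lr": 1e-3} into {"optim": {"lr": 1e-3}}."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        node = nested
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


user_parser = argparse.ArgumentParser(description="my experiment")
parser = argparse.ArgumentParser(parents=[user_parser], add_help=False)
parser.add_argument("--optimizer", dest="optimizer", type=str, default="Adam")
parser.add_argument("--lr", dest="optim.lr", type=float, default=1e-3)
parser.add_argument("--weight-decay", dest="optim.weight_decay", type=float, default=0)
parser.add_argument("--lr-scheduler", dest="lr_scheduler", action="store_true")

args = parser.parse_args(["--lr", "0.01", "--lr-scheduler"])
print(getattr(args, "optim.lr"))  # 0.01
print(nest_dotted(vars(args)))
# {'optimizer': 'Adam', 'optim': {'lr': 0.01, 'weight_decay': 0}, 'lr_scheduler': True}
```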

add_trainer_args(parent_parser)

Augment parser with trainer arguments

Parameters:

Name | Type | Description | Default
parent_parser | ArgumentParser | Parser created by the user | required

Returns:

Type | Description
ArgumentParser | argparse.ArgumentParser: Augmented parser

Source code in slp/plbind/trainer.py
def add_trainer_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Augment parser with trainer arguments

    Args:
        parent_parser (argparse.ArgumentParser): Parser created by the user

    Returns:
        argparse.ArgumentParser: Augmented parser
    """
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument(
        "--seed",
        dest="seed",
        type=int,
        default=None,
        help="Seed for reproducibility",
    )

    parser.add_argument(
        "--config",
        dest="config",
        type=str,  # dir_path,
        default=None,
        help="Path to YAML configuration file",
    )

    parser.add_argument(
        "--experiment-name",
        dest="trainer.experiment_name",
        type=str,
        default="experiment",
        help="Name of the running experiment",
    )

    parser.add_argument(
        "--run-id",
        dest="trainer.run_id",
        type=str,
        default=None,
        help="Unique identifier for the current run. If not provided it is inferred from datetime.now()",
    )

    parser.add_argument(
        "--experiment-group",
        dest="trainer.experiment_group",
        type=str,
        default=None,
        help="Group of current experiment. Useful when evaluating for different seeds / cross-validation etc.",
    )

    parser.add_argument(
        "--experiments-folder",
        dest="trainer.experiments_folder",
        type=str,
        default="experiments",
        help="Top-level folder where experiment results & checkpoints are saved",
    )

    parser.add_argument(
        "--save-top-k",
        dest="trainer.save_top_k",
        type=int,
        default=3,
        help="Save checkpoints for top k models",
    )

    parser.add_argument(
        "--patience",
        dest="trainer.patience",
        type=int,
        default=3,
        help="Number of epochs to wait before early stopping",
    )

    parser.add_argument(
        "--wandb-project",
        dest="trainer.wandb_project",
        type=str,
        default=None,
        help="Wandb project under which results are saved",
    )

    parser.add_argument(
        "--tags",
        dest="trainer.tags",
        type=str,
        nargs="*",
        default=[],
        help="Tags for current run to make results searchable.",
    )

    parser.add_argument(
        "--stochastic_weight_avg",
        dest="trainer.stochastic_weight_avg",
        action="store_true",
        help="Use Stochastic weight averaging.",
    )

    parser.add_argument(
        "--gpus", dest="trainer.gpus", type=int, default=0, help="Number of GPUs to use"
    )

    parser.add_argument(
        "--val-interval",
        dest="trainer.check_val_every_n_epoch",
        type=int,
        default=1,
        help="Run validation every n epochs",
    )

    parser.add_argument(
        "--clip-grad-norm",
        dest="trainer.gradient_clip_val",
        type=float,
        default=0,
        help="Clip gradients with ||grad(w)|| >= args.clip_grad_norm",
    )

    parser.add_argument(
        "--epochs",
        dest="trainer.max_epochs",
        type=int,
        default=100,
        help="Maximum number of training epochs",
    )

    parser.add_argument(
        "--num-nodes",
        dest="trainer.num_nodes",
        type=int,
        default=1,
        help="Number of nodes to run",
    )

    parser.add_argument(
        "--steps",
        dest="trainer.max_steps",
        type=int,
        default=None,
        help="Maximum number of training steps",
    )

    parser.add_argument(
        "--tbtt_steps",
        dest="trainer.truncated_bptt_steps",
        type=int,
        default=None,
        help="Truncated Back-propagation-through-time steps.",
    )

    parser.add_argument(
        "--debug",
        dest="debug",
        action="store_true",
        help="If true, we run a full run on a small subset of the input data and overfit 10 training batches",
    )

    parser.add_argument(
        "--offline",
        dest="trainer.force_wandb_offline",
        action="store_true",
        help="If true, forces offline execution of wandb logger",
    )

    parser.add_argument(
        "--early-stop-on",
        dest="trainer.early_stop_on",
        type=str,
        default="val_loss",
        help="Metric for early stopping",
    )

    parser.add_argument(
        "--early-stop-mode",
        dest="trainer.early_stop_mode",
        type=str,
        choices=["min", "max"],
        default="min",
        help="Minimize or maximize early stopping metric",
    )

    return parser
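Every trainer option uses a trainer. prefix in its dest, and the part after the prefix matches a make_trainer parameter name (for example, trainer.max_epochs and max_epochs). A sketch of collecting those options into keyword arguments; trainer_kwargs is a hypothetical helper, not part of slp:

```python
import argparse
from typing import Any, Dict


def trainer_kwargs(args: argparse.Namespace, prefix: str = "trainer.") -> Dict[str, Any]:
    """Collect every 'trainer.*' destination and strip the prefix."""
    return {
        key[len(prefix):]: value
        for key, value in vars(args).items()
        if key.startswith(prefix)
    }


parser = argparse.ArgumentParser()
parser.add_argument("--seed", dest="seed", type=int, default=None)
parser.add_argument("--epochs", dest="trainer.max_epochs", type=int, default=100)
parser.add_argument("--tags", dest="trainer.tags", type=str, nargs="*", default=[])
parser.add_argument(
    "--early-stop-mode",
    dest="trainer.early_stop_mode",
    type=str,
    choices=["min", "max"],
    default="min",
)

args = parser.parse_args(["--epochs", "5", "--tags", "baseline", "seed-42"])
print(trainer_kwargs(args))
# {'max_epochs': 5, 'tags': ['baseline', 'seed-42'], 'early_stop_mode': 'min'}
```

Options without the prefix (seed, config, debug) are left out, so under this assumption the result could be forwarded as make_trainer(**trainer_kwargs(args)).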

make_trainer(experiment_name='experiment', experiment_description=None, run_id=None, experiment_group=None, experiments_folder='experiments', save_top_k=3, patience=3, wandb_project=None, wandb_user=None, force_wandb_offline=False, tags=None, stochastic_weight_avg=False, auto_scale_batch_size=False, gpus=0, check_val_every_n_epoch=1, gradient_clip_val=0, precision=32, num_nodes=1, max_epochs=100, max_steps=None, truncated_bptt_steps=None, fast_dev_run=None, overfit_batches=None, terminate_on_nan=False, profiler='simple', early_stop_on='val_loss', early_stop_mode='min')

Configure trainer with preferred defaults

  • Experiment folder and run_id configured (based on datetime.now())
  • Wandb and CSV loggers run by default
  • Wandb configured to save code and checkpoints
  • Wandb configured in online mode except if no internet connection is available
  • Early stopping on best validation loss is configured by default
  • Checkpointing on best validation loss is configured by default

Parameters:

Name | Type | Description | Default
experiment_name | str | Experiment name. Defaults to "experiment". | 'experiment'
experiment_description | Optional[str] | Detailed description of the experiment. Defaults to None. | None
run_id | Optional[str] | Unique run_id. If None, it is generated from datetime.now(). Defaults to None. | None
experiment_group | Optional[str] | Group experiments over multiple runs. Defaults to None. | None
experiments_folder | str | Folder to save outputs. Defaults to "experiments". | 'experiments'
save_top_k | int | Save top k checkpoints. Defaults to 3. | 3
patience | int | Patience for early stopping. Defaults to 3. | 3
wandb_project | Optional[str] | Wandb project to save the experiment. Defaults to None. | None
wandb_user | Optional[str] | Wandb username. Defaults to None. | None
force_wandb_offline | bool | Force offline execution of wandb. Defaults to False. | False
tags | Optional[Sequence] | Additional tags to attach to the experiment. Defaults to None. | None
stochastic_weight_avg | bool | Use stochastic weight averaging. Defaults to False. | False
auto_scale_batch_size | bool | Find optimal batch size for the available resources when running trainer.tune(). Defaults to False. | False
gpus | int | Number of GPUs to use. Defaults to 0. | 0
check_val_every_n_epoch | int | Run validation every n epochs. Defaults to 1. | 1
gradient_clip_val | float | Clip gradient norm value. Defaults to 0 (no clipping). | 0
precision | int | Floating point precision. Defaults to 32. | 32
num_nodes | int | Number of nodes to run on. Defaults to 1. | 1
max_epochs | Optional[int] | Maximum number of epochs for training. Defaults to 100. | 100
max_steps | Optional[int] | Maximum number of steps for training. Defaults to None. | None
truncated_bptt_steps | Optional[int] | Truncated backpropagation through time: perform backprop every k steps of a much longer sequence. Defaults to None. | None
fast_dev_run | Optional[int] | Run training on a small number of batches for debugging. Defaults to None. | None
overfit_batches | Optional[int] | Try to overfit a small number of batches for debugging. Defaults to None. | None
terminate_on_nan | bool | Terminate on NaN gradients. Warning: this makes training slow. Defaults to False. | False
profiler | Union[pytorch_lightning.profiler.profilers.BaseProfiler, bool, str] | Profiler used to track the execution time of each function. Defaults to "simple". | 'simple'
early_stop_on | str | Metric to monitor for early stopping. Defaults to "val_loss". | 'val_loss'
early_stop_mode | str | Whether the early stopping metric is minimized ("min") or maximized ("max"). Defaults to "min". | 'min'

Returns:

Type | Description
Trainer | pl.Trainer: Configured trainer
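Before any of the experiment-specific setup, make_trainer short-circuits the two debugging modes (overfit_batches takes precedence over fast_dev_run) and then resolves the logging directory and run_id. A pure-Python sketch of that control flow; the returned dicts stand in for pl.Trainer objects, and the timestamp format of the fallback run_id is an assumption, since slp's date_fname is not shown here:

```python
import os
from datetime import datetime
from typing import Any, Dict, Optional


def plan_trainer(
    experiments_folder: str = "experiments",
    experiment_name: str = "experiment",
    run_id: Optional[str] = None,
    gpus: int = 0,
    overfit_batches: Optional[int] = None,
    fast_dev_run: Optional[int] = None,
) -> Dict[str, Any]:
    # Debug modes: build the trainer from the debug flag and gpus only
    if overfit_batches is not None:
        return {"overfit_batches": overfit_batches, "gpus": gpus}
    if fast_dev_run is not None:
        return {"fast_dev_run": fast_dev_run, "gpus": gpus}

    # The real function also creates this directory on disk; omitted here
    logging_dir = os.path.join(experiments_folder, experiment_name)
    # Assumed timestamp format; the real default comes from slp's date_fname()
    run_id = run_id if run_id is not None else datetime.now().strftime("%Y%m%d-%H%M%S")
    return {"logging_dir": logging_dir, "run_id": run_id, "gpus": gpus}


print(plan_trainer(fast_dev_run=1))  # {'fast_dev_run': 1, 'gpus': 0}
print(plan_trainer(run_id="run-1")["logging_dir"])
```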

Source code in slp/plbind/trainer.py
def make_trainer(
    experiment_name: str = "experiment",
    experiment_description: Optional[str] = None,
    run_id: Optional[str] = None,
    experiment_group: Optional[str] = None,
    experiments_folder: str = "experiments",
    save_top_k: int = 3,
    patience: int = 3,
    wandb_project: Optional[str] = None,
    wandb_user: Optional[str] = None,
    force_wandb_offline: bool = False,
    tags: Optional[Sequence] = None,
    stochastic_weight_avg: bool = False,
    auto_scale_batch_size: bool = False,
    gpus: int = 0,
    check_val_every_n_epoch: int = 1,
    gradient_clip_val: float = 0,
    precision: int = 32,
    num_nodes: int = 1,
    max_epochs: Optional[int] = 100,
    max_steps: Optional[int] = None,
    truncated_bptt_steps: Optional[int] = None,
    fast_dev_run: Optional[int] = None,
    overfit_batches: Optional[int] = None,
    terminate_on_nan: bool = False,  # Be careful this makes training very slow for large models
    profiler: Optional[Union[pl.profiler.BaseProfiler, bool, str]] = "simple",
    early_stop_on: str = "val_loss",
    early_stop_mode: str = "min",
) -> pl.Trainer:
    """Configure trainer with preferred defaults

    * Experiment folder and run_id configured (based on datetime.now())
    * Wandb and CSV loggers run by default
    * Wandb configured to save code and checkpoints
    * Wandb configured in online mode except if no internet connection is available
    * Early stopping on best validation loss is configured by default
    * Checkpointing on best validation loss is configured by default

    Args:
        experiment_name (str, optional): Experiment name. Defaults to "experiment".
        experiment_description (Optional[str], optional): Detailed description of the experiment. Defaults to None.
        run_id (Optional[str], optional): Unique run_id. If None, it is generated from datetime.now(). Defaults to None.
        experiment_group (Optional[str], optional): Group experiments over multiple runs. Defaults to None.
        experiments_folder (str, optional): Folder to save outputs. Defaults to "experiments".
        save_top_k (int, optional): Save top k checkpoints. Defaults to 3.
        patience (int, optional): Patience for early stopping. Defaults to 3.
        wandb_project (Optional[str], optional): Wandb project to save the experiment. Defaults to None.
        wandb_user (Optional[str], optional): Wandb username. Defaults to None.
        force_wandb_offline (bool): Force offline execution of wandb
        tags (Optional[Sequence], optional): Additional tags to attach to the experiment. Defaults to None.
        stochastic_weight_avg (bool, optional): Use stochastic weight averaging. Defaults to False.
        auto_scale_batch_size (bool, optional): Find optimal batch size for the available resources when running
                trainer.tune(). Defaults to False.
        gpus (int, optional): Number of GPUs to use. Defaults to 0.
        check_val_every_n_epoch (int, optional): Run validation every n epochs. Defaults to 1.
        gradient_clip_val (float, optional): Clip gradient norm value. Defaults to 0 (no clipping).
        precision (int, optional): Floating point precision. Defaults to 32.
        num_nodes (int): Number of nodes to run on
        max_epochs (Optional[int], optional): Maximum number of epochs for training. Defaults to 100.
        max_steps (Optional[int], optional): Maximum number of steps for training. Defaults to None.
        truncated_bptt_steps (Optional[int], optional): Truncated backpropagation through time: run backprop
                every k steps of a much longer sequence. Defaults to None.
        fast_dev_run (Optional[int], optional): Run training on a small number of batches for debugging. Defaults to None.
        overfit_batches (Optional[int], optional): Try to overfit a small number of batches for debugging. Defaults to None.
        terminate_on_nan (bool, optional): Terminate on NaN gradients. Warning: this makes training slow. Defaults to False.
        profiler (Optional[Union[pl.profiler.BaseProfiler, bool, str]]): Profiler used to track the execution time of each
                function. Defaults to "simple".
        early_stop_on (str): Metric to monitor for early stopping and checkpointing. Defaults to "val_loss".
        early_stop_mode (str): "min" or "max". Defaults to "min".

    Returns:
        pl.Trainer: Configured trainer
    """

    if overfit_batches is not None:
        trainer = pl.Trainer(overfit_batches=overfit_batches, gpus=gpus)

        return trainer

    if fast_dev_run is not None:
        trainer = pl.Trainer(fast_dev_run=fast_dev_run, gpus=gpus)

        return trainer

    logging_dir = os.path.join(experiments_folder, experiment_name)
    safe_mkdirs(logging_dir)

    run_id = run_id if run_id is not None else date_fname()

    if run_id in os.listdir(logging_dir):
        logger.warning(
            f"The run id you provided {run_id} already exists in {logging_dir}"
        )
        run_id = date_fname()
        logger.info(f"Setting run_id={run_id}")

    checkpoint_dir = os.path.join(logging_dir, run_id, "checkpoints")

    logger.info(f"Logs will be saved in {logging_dir}")
    logger.info(f"Checkpoints will be saved in {checkpoint_dir}")

    if wandb_project is None:
        wandb_project = experiment_name

    connected = has_internet_connection()
    offline_run = force_wandb_offline or not connected

    loggers = [
        pl.loggers.CSVLogger(logging_dir, name="csv_logs", version=run_id),
        FixedWandbLogger(  # type: ignore
            name=experiment_name,
            project=wandb_project,
            anonymous=False,
            save_dir=logging_dir,
            version=run_id,
            save_code=True,
            checkpoint_dir=checkpoint_dir,
            offline=offline_run,
            log_model=not offline_run,
            entity=wandb_user,
            group=experiment_group,
            notes=experiment_description,
            tags=tags,
        ),
    ]

    if gpus > 1:
        del loggers[
            1
        ]  # https://github.com/PyTorchLightning/pytorch-lightning/issues/6106

    logger.info("Configured wandb and CSV loggers.")
    logger.info(
        f"Wandb configured to run {experiment_name}/{run_id} in project {wandb_project}"
    )

    if not offline_run:
        logger.info("Results will be stored online.")
    else:
        logger.info(
            "Results will be stored offline (no internet connection or force_wandb_offline=True)."
        )
        logger.info(
            f"If you want to upload your results later run\n\t wandb sync {logging_dir}/wandb/run-{run_id}"
        )

    if experiment_description is not None:
        logger.info(
            f"Experiment verbose description:\n{experiment_description}\n\nTags:{'n/a' if tags is None else tags}"
        )

    callbacks = [
        EarlyStoppingWithLogs(
            monitor=early_stop_on,
            mode=early_stop_mode,
            patience=patience,
            verbose=True,
        ),
        pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_dir,
            filename="{epoch}-{val_loss:.2f}",
            monitor=early_stop_on,
            save_top_k=save_top_k,
            mode=early_stop_mode,
        ),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
    ]

    logger.info(
        f"Configured early stopping and model checkpointing to track {early_stop_on}"
    )

    trainer = pl.Trainer(
        default_root_dir=logging_dir,
        gpus=gpus,
        max_epochs=max_epochs,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=loggers,
        check_val_every_n_epoch=check_val_every_n_epoch,
        gradient_clip_val=gradient_clip_val,
        auto_scale_batch_size=auto_scale_batch_size,
        stochastic_weight_avg=stochastic_weight_avg,
        precision=precision,
        truncated_bptt_steps=truncated_bptt_steps,
        terminate_on_nan=terminate_on_nan,
        progress_bar_refresh_rate=10,
        profiler=profiler,
        num_nodes=num_nodes,
    )

    return trainer
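
make_trainer attaches EarlyStoppingWithLogs and pl.callbacks.ModelCheckpoint to the same metric (early_stop_on) and direction (early_stop_mode). The pure-Python sketch below replays a list of per-epoch validation values to show how patience and save_top_k interact. simulate_monitoring is a hypothetical helper, not part of slp or Lightning, and it assumes a strict-improvement rule with no min_delta, which simplifies what the real callbacks do.

```python
import heapq
from typing import List, Optional, Tuple


def simulate_monitoring(
    values: List[float], patience: int = 3, mode: str = "min", save_top_k: int = 3
) -> Tuple[Optional[int], List[Tuple[int, float]]]:
    """Replay per-epoch values of a monitored metric (hypothetical helper).

    Returns the epoch at which early stopping would trigger (None if it never
    does) and the (epoch, value) pairs a top-k checkpointer would keep, best first.
    """
    if mode not in ("min", "max"):
        raise ValueError('mode must be "min" or "max"')
    sign = 1.0 if mode == "min" else -1.0  # flip "max" so a lower score is always better
    best = float("inf")
    wait = 0
    kept: List[Tuple[float, int]] = []  # heap of (-score, epoch); kept[0] is the worst kept checkpoint
    stop_epoch: Optional[int] = None

    for epoch, value in enumerate(values):
        score = sign * value

        # Top-k checkpointing: keep the k best scores seen so far
        if save_top_k > 0:
            if len(kept) < save_top_k:
                heapq.heappush(kept, (-score, epoch))
            elif score < -kept[0][0]:
                heapq.heapreplace(kept, (-score, epoch))

        # Early stopping: reset the counter on strict improvement, otherwise count up
        if score < best:
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                stop_epoch = epoch
                break

    # Convert scores back to metric values and sort best first
    checkpoints = sorted(
        ((epoch, sign * -neg) for neg, epoch in kept), key=lambda item: sign * item[1]
    )
    return stop_epoch, checkpoints
```

For example, with values [1.0, 0.8, 0.9, 0.85, 0.95, 0.7], patience=3 and save_top_k=2, training stops at epoch 4 and the kept checkpoints are epochs 1 and 3; the improvement at epoch 5 is never reached. This is why patience and save_top_k are worth tuning together.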

make_trainer_for_ray_tune(patience=3, stochastic_weight_avg=False, gpus=0, gradient_clip_val=0, precision=32, max_epochs=100, max_steps=None, truncated_bptt_steps=None, terminate_on_nan=False, early_stop_on='val_loss', early_stop_mode='min', metrics_map=None, **extra_kwargs)

Configure trainer with preferred defaults

  • Early stopping on best validation loss is configured by default
  • Ray tune callback configured

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| patience | int | Patience for early stopping. Defaults to 3. | 3 |
| stochastic_weight_avg | bool | Use stochastic weight averaging. Defaults to False. | False |
| gpus | int | Number of GPUs to use. Defaults to 0. | 0 |
| gradient_clip_val | float | Clip gradient norm value. Defaults to 0 (no clipping). | 0 |
| precision | int | Floating point precision. Defaults to 32. | 32 |
| max_epochs | Optional[int] | Maximum number of epochs for training. Defaults to 100. | 100 |
| max_steps | Optional[int] | Maximum number of steps for training. Defaults to None. | None |
| truncated_bptt_steps | Optional[int] | Truncated backpropagation through time: run backprop every k steps of a much longer sequence. Defaults to None. | None |
| terminate_on_nan | bool | Terminate on NaN gradients. Warning: this makes training slow. Defaults to False. | False |
| early_stop_on | str | Metric to monitor for early stopping. | 'val_loss' |
| early_stop_mode | str | "min" or "max". | 'min' |
| metrics_map | Optional[Dict[str, str]] | Mapping from Ray Tune metric names (keys) to PyTorch Lightning logged metric names (values). The --tune-metric argument should be one of the keys of this mapping. | None |
| extra_kwargs | kwargs | Ignored. Accepted so that the same config object can be passed here as to make_trainer. | {} |

Returns:

| Type | Description |
|------|-------------|
| Trainer | pl.Trainer: Configured trainer |
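
The truncated_bptt_steps option is easier to picture with a concrete split. With truncated_bptt_steps=k, Lightning cuts each batch along the time dimension into chunks of at most k steps and runs backpropagation separately on each chunk, carrying the hidden state from one chunk to the next. The sketch below shows only the splitting step on a plain Python sequence; split_for_tbptt is a hypothetical helper, and the hidden-state handling between chunks is not shown.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_for_tbptt(sequence: Sequence[T], k: int) -> List[Sequence[T]]:
    """Split a long sequence along time into consecutive chunks of at most k steps.

    Hypothetical helper: in real training each chunk would get its own
    forward/backward pass, with the hidden state passed on to the next chunk.
    """
    if k <= 0:
        raise ValueError("k must be a positive integer")
    return [sequence[start : start + k] for start in range(0, len(sequence), k)]
```

For a sequence of 10 time steps and k=4, this yields chunks covering steps 0-3, 4-7 and 8-9, so gradients never flow back further than 4 steps.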

Source code in slp/plbind/trainer.py
def make_trainer_for_ray_tune(
    patience: int = 3,
    stochastic_weight_avg: bool = False,
    gpus: int = 0,
    gradient_clip_val: float = 0,
    precision: int = 32,
    max_epochs: Optional[int] = 100,
    max_steps: Optional[int] = None,
    truncated_bptt_steps: Optional[int] = None,
    terminate_on_nan: bool = False,  # Be careful this makes training very slow for large models
    early_stop_on: str = "val_loss",
    early_stop_mode: str = "min",
    metrics_map: Optional[Dict[str, str]] = None,
    **extra_kwargs,
) -> pl.Trainer:
    """Configure trainer with preferred defaults

    * Early stopping on best validation loss is configured by default
    * Ray tune callback configured

    Args:
        patience (int, optional): Patience for early stopping. Defaults to 3.
        stochastic_weight_avg (bool, optional): Use stochastic weight averaging. Defaults to False.
        gpus (int, optional): Number of GPUs to use. Defaults to 0.
        gradient_clip_val (float, optional): Clip gradient norm value. Defaults to 0 (no clipping).
        precision (int, optional): Floating point precision. Defaults to 32.
        max_epochs (Optional[int], optional): Maximum number of epochs for training. Defaults to 100.
        max_steps (Optional[int], optional): Maximum number of steps for training. Defaults to None.
        truncated_bptt_steps (Optional[int], optional): Truncated backpropagation through time: run backprop
                every k steps of a much longer sequence. Defaults to None.
        terminate_on_nan (bool, optional): Terminate on NaN gradients. Warning: this makes training slow. Defaults to False.
        early_stop_on (str): Metric to monitor for early stopping. Defaults to "val_loss".
        early_stop_mode (str): "min" or "max". Defaults to "min".
        metrics_map (Optional[Dict[str, str]]): Mapping from Ray Tune metric names (keys)
            to PyTorch Lightning logged metric names (values). The --tune-metric argument
            should be one of the keys of this mapping.
        extra_kwargs (kwargs): Ignored. Accepted so that the same config object can be
            passed here as to make_trainer.

    Returns:
        pl.Trainer: Configured trainer
    """

    if metrics_map is None:
        raise ValueError("Need to pass metrics for TuneReportCallback")

    callbacks = [
        EarlyStoppingWithLogs(
            monitor=early_stop_on,
            mode=early_stop_mode,
            patience=patience,
            verbose=True,
        ),
        TuneReportCallback(metrics_map, on="validation_end"),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
    ]

    logger.info(f"Configured early stopping to track {early_stop_on}")

    trainer = pl.Trainer(
        gpus=gpus,
        max_epochs=max_epochs,
        max_steps=max_steps,
        callbacks=callbacks,
        logger=[],
        check_val_every_n_epoch=1,
        gradient_clip_val=gradient_clip_val,
        stochastic_weight_avg=stochastic_weight_avg,
        precision=precision,
        truncated_bptt_steps=truncated_bptt_steps,
        terminate_on_nan=terminate_on_nan,
        progress_bar_refresh_rate=0,
        num_sanity_val_steps=0,
        auto_scale_batch_size=False,
    )

    return trainer
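
make_trainer_for_ray_tune hands metrics_map to TuneReportCallback. For dict-valued metrics, Ray documents each key as the name reported to Tune and each value as the metric key logged by Lightning, which is why --tune-metric must be one of the keys. The hypothetical helper below mimics that renaming on a plain dict of logged values. It illustrates the convention and is not Ray's implementation.

```python
from typing import Dict, Mapping


def map_metrics_for_tune(
    logged: Mapping[str, float], metrics_map: Mapping[str, str]
) -> Dict[str, float]:
    """Rename Lightning-logged metrics to the names reported to Ray Tune.

    Hypothetical helper. Keys of metrics_map are Tune metric names;
    values are the metric keys logged by Lightning.
    """
    missing = [pl_key for pl_key in metrics_map.values() if pl_key not in logged]
    if missing:
        raise KeyError(f"Metrics not logged by Lightning: {missing}")
    return {tune_key: float(logged[pl_key]) for tune_key, pl_key in metrics_map.items()}
```

With metrics_map={"loss": "val_loss", "acc": "val_accuracy"}, a validation step that logs val_loss and val_accuracy is reported to Tune as loss and acc, so --tune-metric would be either "loss" or "acc".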

watch_model(trainer, model)

If a wandb logger is configured, track gradient and weight histograms. Does nothing when training on more than one GPU.

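Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| trainer | Trainer | Trainer whose logger(s) are searched for wandb | required |
| model | Module | Module to watch | required |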
Source code in slp/plbind/trainer.py
def watch_model(trainer: pl.Trainer, model: nn.Module) -> None:
    """If a wandb logger is configured, track gradient and weight histograms.

    Does nothing when training on more than one GPU.

    Args:
        trainer (pl.Trainer): Trainer whose logger(s) are searched for wandb
        model (nn.Module): Module to watch
    """

    if trainer.num_gpus > 1:
        return

    if isinstance(trainer.logger.experiment, list):
        for log in trainer.logger.experiment:
            try:
                log.watch(model, log="all")
                logger.info("Tracking model weights & gradients in wandb.")

                break
            except Exception:  # this experiment has no usable watch(), try the next one
                pass
    else:
        try:
            trainer.logger.experiment.watch(model, log="all")
            logger.info("Tracking model weights & gradients in wandb.")
        except Exception:  # logger experiment has no usable watch(), skip tracking
            pass
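
watch_model finds the wandb logger by calling watch on each experiment and swallowing whatever fails. A sketch of an alternative, assuming only that the wandb run object exposes a callable watch method and the other experiment objects do not, is to probe with getattr before calling. watch_first and the stub classes are hypothetical stand-ins for the CSV and wandb experiment objects, so the example runs without either library.

```python
from typing import Any, List, Optional, Tuple


def watch_first(experiments: Any, model: Any) -> Optional[Any]:
    """Call watch(model, log="all") on the first experiment that supports it.

    Accepts a single experiment object or a list of them, mirroring the two
    branches in watch_model. Returns the experiment that was watched, or None.
    """
    candidates = experiments if isinstance(experiments, list) else [experiments]
    for experiment in candidates:
        watch = getattr(experiment, "watch", None)
        if callable(watch):
            watch(model, log="all")
            return experiment
    return None


class CsvExperimentStub:
    """Hypothetical stand-in for a CSV logger experiment: no watch method."""


class WandbRunStub:
    """Hypothetical stand-in for a wandb run that records watch() calls."""

    def __init__(self) -> None:
        self.watched: List[Tuple[Any, str]] = []

    def watch(self, model: Any, log: str = "gradients") -> None:
        self.watched.append((model, log))


if __name__ == "__main__":
    csv_exp, wandb_run = CsvExperimentStub(), WandbRunStub()
    assert watch_first([csv_exp, wandb_run], model="my_model") is wandb_run
    print(wandb_run.watched)  # [('my_model', 'all')]
```

Probing first means an exception raised inside watch itself is not silently discarded, at the cost of trusting the attribute check to identify the right logger.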