bpnet

This module contains a reference implementation of BPNet that can be used or adapted for your own circumstances. The implementation takes in a stranded control track and makes predictions for stranded outputs.

class bpnetlite.bpnet.BPNet(n_filters=64, n_layers=8, n_outputs=2, n_control_tracks=2, count_loss_weight=1, profile_output_bias=True, count_output_bias=True, name=None, trimming=None, verbose=True)

A basic BPNet model with stranded profile and total count prediction.

This is a reference implementation for BPNet models. It exactly matches the architecture in the official ChromBPNet repository. It is very similar to the implementation in the official basepairmodels repository but differs in when the activation function is applied in the residual layers. See the BasePairNet object below for an implementation that matches that repository.

The model takes in one-hot encoded sequence, runs it through:

(1) a single wide convolution operation

THEN

(2) a user-defined number of dilated residual convolutions

THEN

(3a) profile prediction, made using a very wide convolution layer that also takes in the stranded control tracks

AND

(3b) total count prediction, made by average pooling the output from (2), concatenating the result with the log1p of the sum of the stranded control tracks, and running that through a dense layer.
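The steps above can be sketched as a small PyTorch module. This is a hypothetical, simplified illustration rather than bpnetlite's source: the class name `BPNetSketch`, the kernel sizes (21, 3 and 75), the dilation schedule `2**i`, the default trimming of 557 (which maps a 2114 bp input to a 1000 bp output) and the window over which control counts are summed are all assumptions.

```python
import torch


class BPNetSketch(torch.nn.Module):
    """Hypothetical, simplified BPNet-style model (not bpnetlite's code).

    Kernel sizes (21, 3, 75) and the dilation schedule 2**i are assumptions.
    """

    def __init__(self, n_filters=64, n_layers=8, n_outputs=2,
                 n_control_tracks=2, trimming=557):
        super().__init__()
        self.trimming = trimming
        self.n_control_tracks = n_control_tracks
        # (1) a single wide convolution over the one-hot sequence
        self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
        # (2) dilated residual convolutions; padding == dilation keeps the length
        self.rconvs = torch.nn.ModuleList([
            torch.nn.Conv1d(n_filters, n_filters, kernel_size=3,
                            padding=2 ** i, dilation=2 ** i)
            for i in range(1, n_layers + 1)
        ])
        # (3a) very wide profile convolution that also sees the control tracks
        self.fconv = torch.nn.Conv1d(n_filters + n_control_tracks, n_outputs,
                                     kernel_size=75, padding=37)
        # (3b) dense layer over pooled features (+ one control scalar if used)
        n_in = n_filters + (1 if n_control_tracks > 0 else 0)
        self.linear = torch.nn.Linear(n_in, 1)

    def forward(self, X, X_ctl=None):
        start, end = self.trimming, X.shape[2] - self.trimming
        X = torch.relu(self.iconv(X))
        for rconv in self.rconvs:
            # Residual added to the post-activation output of the previous layer
            X = X + torch.relu(rconv(X))

        # (3a) concatenate control tracks, convolve, trim to the output window
        X_prof = X if X_ctl is None else torch.cat([X, X_ctl], dim=1)
        y_profile = self.fconv(X_prof)[:, :, start:end]

        # (3b) average-pool the output of step (2) over the output window
        X_count = X[:, :, start:end].mean(dim=2)
        if X_ctl is not None:
            # log1p of the total control signal (summing window is assumed)
            ctl = torch.log1p(X_ctl[:, :, start:end].sum(dim=(1, 2))).unsqueeze(1)
            X_count = torch.cat([X_count, ctl], dim=1)
        y_counts = self.linear(X_count)
        return y_profile, y_counts
```

Because every convolution is padded to preserve length, the only place the window shrinks is the explicit trim, which keeps the input/output length relationship easy to reason about.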

This implementation differs from the original BPNet implementation in four ways:

(1) The model concatenates the stranded control tracks for profile prediction, rather than adding the two strands together and then smoothing the resulting track.

(2) The control input for the count prediction task is the log1p of the strand-wise sum of the control tracks, as opposed to the raw counts themselves.

(3) A single log softmax is applied across both strands such that the logsumexp of both strands together is 0. Put another way, the two strands are concatenated together, a log softmax is applied, and the MNLL loss is calculated on the concatenation.

(4) The count prediction task predicts the total counts across both strands. The counts are then distributed across strands according to the single log softmax from (3).
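Points (3) and (4) can be illustrated with a few tensor operations. This is a sketch under assumptions, not bpnetlite's loss code: the function names are hypothetical, the MNLL shown includes the multinomial coefficient, and `expected_counts` assumes the count head predicts log(total + 1) with shape (batch, 1).

```python
import torch


def stranded_log_probs(y_profile):
    """Single log softmax over both strands, flattened together."""
    batch = y_profile.shape[0]
    flat = y_profile.reshape(batch, -1)          # (batch, n_strands * length)
    return torch.log_softmax(flat, dim=-1).reshape(y_profile.shape)


def mnll(log_probs, y_true):
    """Multinomial negative log-likelihood on the strand concatenation."""
    batch = y_true.shape[0]
    logp = log_probs.reshape(batch, -1)
    y = y_true.reshape(batch, -1)
    N = y.sum(dim=-1)
    # log of the multinomial coefficient N! / prod(y_i!)
    log_coef = torch.lgamma(N + 1) - torch.lgamma(y + 1).sum(dim=-1)
    return -(log_coef + (y * logp).sum(dim=-1))


def expected_counts(y_profile, log_total):
    """Distribute a predicted total across strands and positions.

    Assumes log_total = log(total + 1) with shape (batch, 1).
    """
    probs = stranded_log_probs(y_profile).exp()
    total = torch.expm1(log_total)                # (batch, 1)
    return probs * total.unsqueeze(-1)            # broadcast to (batch, n_strands, length)
```

Because a single softmax covers both strands, the strand ratio is decided entirely by the profile head, and the count head only has to predict one number per example.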

Parameters

n_filters: int, optional

The number of filters to use per convolution. Default is 64.

n_layers: int, optional

The number of dilated residual layers to include in the model. Default is 8.

n_outputs: int, optional

The number of profile outputs from the model. Generally either 1 or 2 depending on if the data is unstranded or stranded. Default is 2.

n_control_tracks: int, optional

The number of control tracks to feed into the model. When predicting TFs, this is usually 2. When predicting accessibility, this is usually 0. When 0, this input is removed from the model. Default is 2.

count_loss_weight: float, optional

The weight to put on the count loss. Default is 1.

profile_output_bias: bool, optional

Whether to include a bias term in the final profile convolution. Removing this term can help with attribution stability and will usually not affect performance. Default is True.

count_output_bias: bool, optional

Whether to include a bias term in the linear layer used to predict counts. Removing this term can help with attribution stability but may affect performance. Default is True.

name: str or None, optional

The name to save the model to during training.

trimming: int or None, optional

The amount to trim from both sides of the input window to get the output window. This value is removed from both sides, so the total number of positions removed is 2*trimming.

verbose: bool, optional

Whether to display statistics during training. Setting this to False will still save the file at the end, but does not print anything to screen during training. Default is True.
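The trimming arithmetic is simple but easy to get off by a factor of two. The sketch below uses hypothetical helper names; with the 2114 bp inputs and 1000 bp outputs that appear in the fit parameter shapes, the trimming works out to 557 on each side.

```python
def output_length(input_length: int, trimming: int) -> int:
    """Positions left after removing `trimming` from each side."""
    return input_length - 2 * trimming


def trimming_for(input_length: int, out_length: int) -> int:
    """Per-side trimming needed to go from input_length to out_length."""
    removed = input_length - out_length
    if removed < 0 or removed % 2:
        raise ValueError("input_length - out_length must be non-negative and even")
    return removed // 2


print(trimming_for(2114, 1000))   # 557
print(output_length(2114, 557))   # 1000
```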

fit(training_data, optimizer, scheduler=None, X_valid=None, X_ctl_valid=None, y_valid=None, max_epochs=100, batch_size=64, dtype='float32', device='cuda', early_stopping=None)

Fit the model to data and validate it periodically.

This method controls the training of a BPNet model. It will fit the model to examples generated by the training_data DataLoader object and, if validation data is provided, will validate the model against it at the end of each epoch and return those values.

Two versions of the model will be saved: the best model found during training according to the validation measures, and the final model at the end of training. Additionally, a log will be saved of the training and validation statistics, e.g. time and performance.
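A minimal way to build a compatible `training_data` loader is sketched below with random placeholder tensors. The batch layout (sequence, control, target) is an assumption inferred from the order of the validation arguments, and the commented-out `fit` call is illustrative only (the optimizer choice and learning rate are arbitrary).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
n = 32

# Random one-hot sequences, shape (n, 4, 2114).
idx = torch.randint(0, 4, (n, 2114))
X = torch.nn.functional.one_hot(idx, num_classes=4).permute(0, 2, 1).float()

# Placeholder stranded control tracks, shape (n, n_control_tracks=2, 2114).
X_ctl = torch.rand(n, 2, 2114)

# Placeholder stranded targets, shape (n, n_outputs=2, 1000).
y = torch.randint(0, 5, (n, 2, 1000)).float()

# Assumed batch layout: (sequence, control, target).
training_data = DataLoader(TensorDataset(X, X_ctl, y), batch_size=8, shuffle=True)

# With bpnetlite installed, training might then look like this (not run here;
# reusing the same tensors for validation is only for brevity):
# from bpnetlite.bpnet import BPNet
# model = BPNet(n_outputs=2, n_control_tracks=2, trimming=557)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# model.fit(training_data, optimizer, X_valid=X, X_ctl_valid=X_ctl,
#           y_valid=y, max_epochs=10)
```

For an accessibility model with n_control_tracks=0, the same pattern would drop `X_ctl` from the dataset.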

Parameters

training_data: torch.utils.data.DataLoader

A DataLoader that yields examples to train on. If n_control_tracks is greater than 0, each batch must contain two inputs (the sequences and the control tracks) in addition to the targets; otherwise it must contain only one input.

optimizer: torch.optim.Optimizer

An optimizer to control the training of the model.

scheduler: torch.optim.lr_scheduler, optional

An optional learning rate scheduler which changes the learning rate across batches. If None, do not use a scheduler. Default is None.

X_valid: torch.tensor or None, shape=(n, 4, 2114)

A block of sequences to validate on periodically. If None, do not perform validation. Default is None.

X_ctl_valid: torch.tensor or None, shape=(n, n_control_tracks, 2114)

A block of control tracks to validate on periodically. If n_control_tracks is 0, pass in None. Default is None.

y_valid: torch.tensor or None, shape=(n, n_outputs, 1000)

A block of signals to validate against. Must be provided if X_valid is also provided. Default is None.

max_epochs: int, optional

The maximum number of epochs to train for, as measured by the number of times that training_data is exhausted. Default is 100.

batch_size: int, optional

The number of examples to include in each batch. Default is 64.

dtype: str or torch.dtype, optional

The torch.dtype to use when training. Usually this will be torch.float32 or torch.bfloat16. Default is 'float32'.

device: str, optional

The device to use for training and inference. Typically this will be 'cuda', but it can be any device supported by torch. Default is 'cuda'.

early_stopping: int or None, optional

Whether to stop training early. If None, continue training until max_epochs is reached. If an integer, continue training until that number of epochs has been hit without improvement in performance. Default is None.
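Patience-based early stopping of the kind described for `early_stopping` can be sketched as a small counter. This is a hypothetical helper, not the code inside `fit`; it assumes a lower validation loss means better performance.

```python
class EarlyStoppingSketch:
    """Hypothetical patience counter; not bpnetlite's implementation."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, valid_loss):
        """Record one epoch's validation loss; return True to stop."""
        if valid_loss < self.best:
            self.best = valid_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStoppingSketch(patience=2)
print([stopper.step(loss) for loss in [5.0, 4.0, 4.5, 4.2]])
# [False, False, False, True]
```

The counter resets on every improvement, so training stops only after `patience` consecutive epochs without a new best.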

forward(X, X_ctl=None)

A forward pass of the model.

This method takes in a nucleotide sequence X and, optionally, a corresponding per-position control track X_ctl, and makes predictions for the profile and for the counts. The count head also uses a per-locus value derived from the control track; this is usually log(sum(X_ctl_profile)+1) when the control is an experimental read track, but the control can also be the output from another model.

Parameters

X: torch.tensor, shape=(batch_size, 4, length)

The one-hot encoded batch of sequences.

X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length)

A value representing the signal of the control at each position in the sequence. If no controls, pass in None. Default is None.

Returns

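y_profile: torch.tensor, shape=(batch_size, n_strands, out_length)

The output predictions for each strand trimmed to the output length.

y_counts: torch.tensor, shape=(batch_size, 1)

The predicted total counts for each example across both strands.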

classmethod from_chrombpnet(filename)

Loads a model from ChromBPNet TensorFlow format.

This method will load one of the components of a ChromBPNet model from TensorFlow format. Note that a full ChromBPNet model is made up of an accessibility model and a bias model and that this will load one of the two. Use ChromBPNet.from_chrombpnet to end up with the entire ChromBPNet model.

Parameters

filename: str

The name of the h5 file that stores the trained model parameters.

Returns

model: BPNet

A BPNet model compatible with this repository in PyTorch.

classmethod from_chrombpnet_lite(filename)

Loads a model from ChromBPNet-lite TensorFlow format.

This method will load a ChromBPNet-lite model from TensorFlow format. Note that this is not the same as ChromBPNet format. Specifically, ChromBPNet-lite was a preceding package that had a slightly different saving format, whereas ChromBPNet is the packaged version of that code that is applied at scale.

This method does not load the entire ChromBPNet model. If that is the desired behavior, see the ChromBPNet object and its associated loading functions. Instead, this loads a single BPNet model – either the bias model or the accessibility model, depending on what is encoded in the stored file.

Parameters

filename: str

The name of the h5 file that stores the trained model parameters.

Returns

model: BPNet

A BPNet model compatible with this repository in PyTorch.

class bpnetlite.bpnet.BasePairNet(n_filters=64, n_layers=8, n_outputs=2, n_control_tracks=2, count_loss_weight=1, profile_output_bias=True, count_output_bias=True, name=None, trimming=None, verbose=True)

A BPNet implementation matching that in basepairmodels.

This is a BPNet implementation that matches the one in basepairmodels and can be used to load models trained with that repository, e.g., those trained as part of the atlas project. The architecture of the model is identical to BPNet except that the output from each residual layer is added to the pre-activation output of the previous layer, rather than to its post-activation output. Additionally, the count prediction head takes the sum of the control track counts, adds two instead of one, and then takes the log. Neither detail dramatically changes the performance of the model, but both must be accounted for when loading trained models.
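The difference in the count head's control input is small but matters for reproducing stored weights. The sketch below contrasts the two per-locus control features; the function names are hypothetical, and summing over every strand and position of `X_ctl` (rather than a trimmed window) is an assumption.

```python
import torch


def ctl_count_feature_bpnet(X_ctl):
    # BPNet: log(sum + 1), i.e. log1p of the summed control signal
    return torch.log1p(X_ctl.sum(dim=(1, 2))).unsqueeze(1)


def ctl_count_feature_basepairnet(X_ctl):
    # BasePairNet: add two instead of one before taking the log
    return torch.log(X_ctl.sum(dim=(1, 2)) + 2).unsqueeze(1)


X_ctl = torch.ones(1, 2, 3)                  # total control signal = 6
print(ctl_count_feature_bpnet(X_ctl))        # log(7)
print(ctl_count_feature_basepairnet(X_ctl))  # log(8)
```

For an all-zero control the two features are 0 and log(2) respectively, so a model trained with one convention will see a constant offset if loaded with the other.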

Parameters

n_filters: int, optional

The number of filters to use per convolution. Default is 64.

n_layers: int, optional

The number of dilated residual layers to include in the model. Default is 8.

n_outputs: int, optional

The number of profile outputs from the model. Generally either 1 or 2 depending on if the data is unstranded or stranded. Default is 2.

n_control_tracks: int, optional

The number of control tracks to feed into the model. When predicting TFs, this is usually 2. When predicting accessibility, this is usually 0. When 0, this input is removed from the model. Default is 2.

count_loss_weight: float, optional

The weight to put on the count loss. Default is 1.

profile_output_bias: bool, optional

Whether to include a bias term in the final profile convolution. Removing this term can help with attribution stability and will usually not affect performance. Default is True.

count_output_bias: bool, optional

Whether to include a bias term in the linear layer used to predict counts. Removing this term can help with attribution stability but may affect performance. Default is True.

name: str or None, optional

The name to save the model to during training.

trimming: int or None, optional

The amount to trim from both sides of the input window to get the output window. This value is removed from both sides, so the total number of positions removed is 2*trimming.

verbose: bool, optional

Whether to display statistics during training. Setting this to False will still save the file at the end, but does not print anything to screen during training. Default is True.

forward(X, X_ctl=None)

A forward pass of the model.

This method takes in a nucleotide sequence X and, optionally, a corresponding per-position control track X_ctl, and makes predictions for the profile and for the counts. The count head also uses a per-locus value derived from the control track; in this implementation it is log(sum(X_ctl_profile)+2) rather than log(sum(X_ctl_profile)+1), as described above, and the control can also be the output from another model.

Parameters

X: torch.tensor, shape=(batch_size, 4, length)

The one-hot encoded batch of sequences.

X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length)

A value representing the signal of the control at each position in the sequence. If no controls, pass in None. Default is None.

Returns

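y_profile: torch.tensor, shape=(batch_size, n_strands, out_length)

The output predictions for each strand trimmed to the output length.

y_counts: torch.tensor, shape=(batch_size, 1)

The predicted total counts for each example across both strands.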

classmethod from_bpnet(filename)

Loads a model from BPNet TensorFlow format.

This method will allow you to load a BPNet model from the basepairmodels repo that has been saved in TensorFlow format. You do not need to have TensorFlow installed to use this function. The result will be a model whose predictions and attributions are identical to those produced when using the TensorFlow code.

Parameters

filename: str

The name of the h5 file that stores the trained model parameters.

Returns

model: BasePairNet

A BasePairNet model compatible with this repository in PyTorch.

class bpnetlite.bpnet.ControlWrapper(model)

This wrapper automatically creates a control track of all zeroes.

This wrapper checks whether the model is expecting a control track (e.g., most BPNet-style models) and, if so, creates one with the expected shape. If no control track is expected, it returns the normal output from the model.
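A minimal sketch of this behavior is shown below. It is not bpnetlite's `ControlWrapper`: the class name is hypothetical, and it assumes the wrapped model exposes its expected number of control tracks as an `n_control_tracks` attribute and that the control track has the same length as the input sequence.

```python
import torch


class ZeroControlWrapperSketch(torch.nn.Module):
    """Hypothetical stand-in for ControlWrapper."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, X, X_ctl=None):
        # Assumption: the wrapped model exposes `n_control_tracks`.
        n_ctl = getattr(self.model, "n_control_tracks", 0)
        if n_ctl > 0:
            if X_ctl is None:
                # All-zero control with shape (batch, n_control_tracks, length)
                X_ctl = torch.zeros(X.shape[0], n_ctl, X.shape[-1],
                                    dtype=X.dtype, device=X.device)
            return self.model(X, X_ctl)
        # No control expected: run the model on the sequence alone
        return self.model(X)
```

Creating the zeros on the same dtype and device as `X` avoids device-mismatch errors when the wrapper is used on a GPU.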

forward(X, X_ctl=None)

Run the wrapped model, first creating an all-zero control track of the expected shape if the model expects one.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class bpnetlite.bpnet.CountWrapper(model)

A wrapper class that only returns the predicted counts.

This class takes in a trained model and returns only the second output. For BPNet models, this means that it is only returning the count predictions. This is for convenience when using captum to calculate attribution scores.
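The idea reduces to indexing the model's output tuple. The sketch below uses a hypothetical class name and assumes the wrapped model returns `(y_profile, y_counts)`; attribution libraries such as captum then see a single tensor output.

```python
import torch


class CountOnlySketch(torch.nn.Module):
    """Hypothetical stand-in for CountWrapper: keep only the second output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, X, X_ctl=None, **kwargs):
        # Assumes the model returns (y_profile, y_counts)
        return self.model(X, X_ctl, **kwargs)[1]
```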

Parameters

model: torch.nn.Module

A torch model to be wrapped.

forward(X, X_ctl=None, **kwargs)

Run the wrapped model and return only its second output, the count predictions.


class bpnetlite.bpnet.ProfileWrapper(model)

A wrapper class that returns transformed profiles.

This class takes in a trained model and returns a softmax-weighted summary of its first output, the profile logits. Specifically, it takes the predicted "logits" and computes the dot product between them and their softmaxed values. This is for convenience when using captum to calculate attribution scores.
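The transform itself is a single dot product per example. The sketch below is a hypothetical helper, not bpnetlite's code; flattening both strands into one softmax is an assumption chosen to match the single cross-strand log softmax described for BPNet.

```python
import torch


def weighted_softmax_profile(logits):
    """Dot product of the logits with their own softmax, per example.

    logits: (batch, n_strands, length); strands are flattened together.
    Returns a tensor of shape (batch, 1).
    """
    flat = logits.reshape(logits.shape[0], -1)
    probs = torch.softmax(flat, dim=-1)
    return (flat * probs).sum(dim=-1, keepdim=True)
```

Positions the model assigns high probability to dominate the sum, so attributions computed on this scalar focus on the peaks of the predicted profile rather than on its flat background.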

Parameters

model: torch.nn.Module

A torch model to be wrapped.

forward(X, X_ctl=None, **kwargs)

Run the wrapped model and return the softmax-weighted sum of its profile logits (the first output).


class bpnetlite.bpnet._ProfileLogitScaling

This ugly class is necessary because of Captum.

Captum internally registers classes as linear or non-linear. Because the profile wrapper performs some non-linear operations, those operations must be registered as such. However, the inputs to the wrapper are not the logits that are being modified in a non-linear manner but rather the original sequence that is subsequently run through the model. Hence, this object will contain all of the operations performed on the logits and can be registered.
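The structural point is that every non-linear operation on the logits lives inside one `torch.nn.Module` submodule, so that module (rather than the sequence input) is what gets treated as non-linear. The sketch below uses hypothetical class names and does not show how the module is registered with Captum, since that step is not described here.

```python
import torch


class LogitScalingSketch(torch.nn.Module):
    """Holds every non-linear op applied to the logits in one module."""

    def forward(self, logits):
        probs = torch.softmax(logits, dim=-1)
        return (logits * probs).sum(dim=-1, keepdim=True)


class ProfileScoreSketch(torch.nn.Module):
    """Hypothetical wrapper that delegates the logit math to a submodule."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.scaling = LogitScalingSketch()   # a registered submodule

    def forward(self, X, X_ctl=None, **kwargs):
        logits = self.model(X, X_ctl, **kwargs)[0]
        # Flatten strands and positions before the softmax weighting
        return self.scaling(logits.reshape(X.shape[0], -1))
```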

Parameters

logits: torch.Tensor, shape=(-1, -1)

The logits as they come out of a Chrom/BPNet model.

forward(logits)

Apply the scaling to the logits: take the dot product between the logits and their softmaxed values.
