How To: Train a BPNet model using the Python API

There are two ways to train a BPNet model using bpnet-lite: (1) the Python API, discussed here, and (2) the command-line tools. Each approach has trade-offs: the Python API is likely more familiar to developers and easier to integrate into other codebases, whereas the command-line tools are easier for those who just want to train a model as quickly as possible. Because bpnet-lite aims to be as lightweight and low-level as possible, the Python API provides a great deal of flexibility for changing the standard pipeline.

Here, we will focus on a BPNet model trained to predict CTCF binding in HepG2. Specifically, this is ENCODE Accession ENCSR000BIE. To be clear, we use a lot of ENCODE data in these tutorials because it is well organized and easy to access, but there is no requirement that your data come from the ENCODE Portal. All you need are bigWigs with the signals, as well as peaks and negatives (which can be calculated from the peaks using bpnet negatives).

Loading the Data

The first step is to specify where your data files are. Unlike the command-line tool, the Python API does not handle the preprocessing of data, so you will need to provide bigWigs with your signal, peak and negative files, and the genome you want to use. The peaks and negatives can be either bed or bed.gz and can be hosted remotely if necessary.

[1]:
seq = "/home/jmschrei/common/hg38.fa"
controls = ['ctcf-bpnet/ctcf-hepg2-control.+.bw', 'ctcf-bpnet/ctcf-hepg2-control.-.bw']
signals = ['ctcf-bpnet/ctcf-hepg2.+.bw', 'ctcf-bpnet/ctcf-hepg2.-.bw']

peaks = 'ctcf-bpnet/ENCFF199YFA.bed.gz'
negatives = 'ctcf-bpnet/ctcf-hepg2.negatives.bed'

Next, we wrap the training data in a generator that will be used to train the model. This generator starts by loading all of the peaks and negatives into memory, with padding on either side equal to the maximum jitter. It loads the sequence, which is the input to the model; the signal, which is being predicted; and the control tracks, which are optional inputs to the model. During training, slices of these padded examples are taken to simulate jittering while still allowing the data to be pre-loaded into memory, so nothing needs to be read from disk. We also specify which chromosomes to use for training: this allows us to pass the same peak and negative files into each of the data loaders and have each one filter out examples that are not on the specified chromosomes.

[2]:
from bpnetlite.io import PeakGenerator

training_chroms = [
    "chr2", "chr4", "chr5", "chr7", "chr9", "chr10", "chr11", "chr12",
    "chr13", "chr14", "chr15", "chr16", "chr17", "chr18", "chr19",
    "chr21", "chr22", "chrX", "chrY"
]

training_data = PeakGenerator(
    peaks=peaks,
    negatives=negatives,
    sequences=seq,
    signals=signals,
    controls=controls,
    chroms=training_chroms,
    negative_ratio=0.33,
    random_state=0,
    batch_size=64,
    verbose=True
)
Loading Loci: 100%|█████████████████████████████████████████████████████████████| 43391/43391 [00:04<00:00, 8938.63it/s]
Loading Loci: 100%|████████████████████████████████████████████████████████████| 42358/42358 [00:04<00:00, 10178.07it/s]

Filtered Peaks: 197
Filtered Negatives: 0

It looks like a few examples were filtered out. This happens when an extracted example runs off the end of a chromosome, falls within a blacklisted region (none are provided here, but they can be), or contains too many N characters.
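
The jittering described above can be sketched with a toy example. This is only an illustration of the slicing idea, not the PeakGenerator implementation: the function name `jittered_window` and the toy sizes are hypothetical, and the real generator presumably applies the same offset to the sequence, signal, and control tracks of an example.

```python
import random

def jittered_window(padded, window, max_jitter, rng):
    """Return a window-length slice of a padded example at a random offset."""
    # The padded example holds max_jitter extra positions on each side,
    # so valid start offsets run from 0 to 2 * max_jitter inclusive.
    assert len(padded) == window + 2 * max_jitter
    start = rng.randint(0, 2 * max_jitter)
    return padded[start:start + window]

rng = random.Random(0)
window, max_jitter = 10, 3

# Stand-in for one pre-loaded, padded locus (positions 0..15).
padded = list(range(window + 2 * max_jitter))

# Each draw sees a slightly shifted view of the same locus, with no disk access.
slices = [jittered_window(padded, window, max_jitter, rng) for _ in range(5)]
for s in slices:
    print(s[0], s[-1])  # first and last padded position covered by the window
```

Because the padding equals the maximum jitter, every possible offset stays inside the data already held in memory.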

Next, we load the validation data. Because we need neither padding nor batch sampling here, we can use the extract_loci function directly.

[3]:
from tangermeme.io import extract_loci

validation_chroms = ['chr8', 'chr20']

X_valid, y_valid, X_ctl_valid = extract_loci(
    sequences=seq,
    signals=signals,
    in_signals=controls,
    loci=peaks,
    chroms=validation_chroms,
    max_jitter=0,
    verbose=True
)

X_valid.shape, X_ctl_valid.shape, y_valid.shape
Loading Loci: 100%|██████████████████████████████████████████████████████████████| 4780/4780 [00:00<00:00, 10051.56it/s]
[3]:
(torch.Size([4780, 4, 2114]),
 torch.Size([4780, 2, 2114]),
 torch.Size([4780, 2, 1000]))
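
As a quick sanity check on the splits, the short sketch below confirms that the training and validation chromosome lists used above do not overlap and lists which chromosomes appear in neither. The lists are copied from the cells above; the idea that the remaining chromosomes are reserved for testing is an assumption, since the tutorial does not say so.

```python
training_chroms = [
    "chr2", "chr4", "chr5", "chr7", "chr9", "chr10", "chr11", "chr12",
    "chr13", "chr14", "chr15", "chr16", "chr17", "chr18", "chr19",
    "chr21", "chr22", "chrX", "chrY"
]
validation_chroms = ["chr8", "chr20"]

# Standard human chromosome names: chr1..chr22 plus the sex chromosomes.
all_chroms = [f"chr{i}" for i in range(1, 23)] + ["chrX", "chrY"]

# A shared chromosome would leak validation loci into training.
overlap = set(training_chroms) & set(validation_chroms)

# Chromosomes in neither list (assumed here to be held out for testing).
held_out = [
    c for c in all_chroms
    if c not in training_chroms and c not in validation_chroms
]

print(overlap)
print(held_out)
```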

BPNet Model and Training

Next, we define the BPNet model architecture parameters. There are many you can modify, including the number of filters and the number of layers, but the most important ones are the number of outputs and the number of control tracks. Because we are predicting stranded outputs and have two control tracks, we set both to 2, but the two values do not need to match. Note, however, that BPNet predicts only a single log count value regardless of the number of output tracks. This is the normal behavior when predicting chromatin accessibility or transcription factor binding, but it may produce odd results if you start making predictions for many different tracks.
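
To make the single log count value concrete, here is a toy sketch of what one count target per example could look like when there are two strands. The data are made up, and the use of log(1 + total reads across all tracks) as the target is an assumption for illustration rather than a statement of what bpnet-lite does internally.

```python
import math

# Toy read counts for one example: two strands, five positions each.
y = [
    [0, 2, 5, 1, 0],  # + strand
    [1, 0, 3, 4, 0],  # - strand
]

# Per-strand totals describe each track separately.
per_strand_totals = [sum(track) for track in y]

# A single count output has to summarize all tracks at once.
# Assumption: the target is log(1 + total reads across both strands).
total_reads = sum(per_strand_totals)
log_count_target = math.log(total_reads + 1)

print(per_strand_totals, total_reads, log_count_target)
```

With many unrelated tracks, a single total like this would mix signals of very different scales, which is one way the odd results mentioned above could arise.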

An important consideration when training BPNet models is setting count_loss_weight, which is the weight on the count loss. The command-line tool calculates this automatically, but it must be set manually when you use the Python API. It is calculated as the total number of reads in each example, averaged across examples and strands. The pipeline uses the following code to determine it:

if parameters['count_loss_weight'] is None:
    peak_read_count = training_data.dataset.peak_signals.sum(axis=-1)
    count_loss_weight = peak_read_count.mean(axis=(0, 1)).item()
    parameters['count_loss_weight'] = count_loss_weight
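
As a toy illustration of this calculation, the sketch below applies the same two steps (sum over positions, then mean over examples and strands) to a small nested list standing in for peak_signals. The numbers are invented, and plain Python lists are used in place of tensors so the arithmetic is easy to follow.

```python
# Toy stand-in for training_data.dataset.peak_signals with shape
# (n_examples=3, n_strands=2, length=4). Values are made up.
peak_signals = [
    [[1, 0, 2, 1], [0, 3, 1, 0]],
    [[5, 5, 0, 0], [2, 2, 2, 2]],
    [[0, 0, 0, 1], [1, 1, 0, 0]],
]

# Step 1: sum over positions, giving one read count per (example, strand).
read_counts = [[sum(strand) for strand in example] for example in peak_signals]

# Step 2: average over both examples and strands to get a single scalar.
flat = [count for example in read_counts for count in example]
count_loss_weight = sum(flat) / len(flat)

print(read_counts)
print(count_loss_weight)
```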

Here, we are using the value derived from the official repository so that we can have a fair comparison later on.

[4]:
from bpnetlite.bpnet import BPNet

model = BPNet(
    n_outputs=2,
    n_control_tracks=2,
    count_loss_weight=124.035,
    name='ctcf-bpnet/model',
    verbose=True
).cuda()

Now, we have to specify the optimizer. Vanilla Adam with a learning rate of 0.001 is the default.

[5]:
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

BPNet also uses a scheduler that cuts the learning rate in half when a specified number of epochs have passed without the loss improving. Note that this interacts with the early stopping built into the fit function, which will usually terminate training before the threshold and minimum learning rate set here are reached.

[6]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,
    patience=5, threshold=0.0001, min_lr=0.0001)
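
To see what these settings do, the following is a simplified pure-Python re-implementation of the plateau rule, not PyTorch's own code. It assumes the scheduler's default behavior of minimizing the monitored value with a relative threshold and no cooldown; the function name `simulate_plateau_schedule` and the flat toy losses are hypothetical.

```python
def simulate_plateau_schedule(losses, lr=0.001, factor=0.5, patience=5,
                              threshold=1e-4, min_lr=1e-4):
    """Return the learning rate in effect after each epoch's loss is seen."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        # Relative threshold: the loss must beat the best value by a factor
        # of (1 - threshold) to count as an improvement.
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Cut the rate only once patience is exceeded, never below min_lr.
        if bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history

# One improving epoch followed by a long plateau (31 epochs in total).
losses = [1.0] * 31
history = simulate_plateau_schedule(losses)
print(history[5], history[6], history[12], history[18], history[24])
```

Under this simplified model, a cut needs six consecutive non-improving epochs, so a fit call with early_stopping=5 would typically stop before the first cut, which is consistent with the remark above.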

Then, we can train the model using the fit function. Here, we pass in a few training-specific hyperparameters as well as the training and validation data.

[7]:
model.fit(
    training_data, optimizer, scheduler,
    X_valid=X_valid,
    X_ctl_valid=X_ctl_valid,
    y_valid=y_valid,
    max_epochs=50,
    batch_size=128,
    early_stopping=5,
    device='cuda'
)
Warning: BPNet and ChromBPNet models trained using bpnet-lite may underperform those trained using the official repositories. See the GitHub README for further documentation.
Epoch  Iteration  Training Time  Validation Time  Training MNLL  Training Count MSE  Validation MNLL  Validation Profile Pearson  Validation Count Pearson  Validation Count MSE  Saved?
0      898        6.6875         0.3842           381.8342       0.5251              407.2075         0.41983423                  0.59712267                0.4542                True
1      1796       6.1871         0.1942           350.2059       0.3744              393.9158         0.43557942                  0.71063036                0.3375                True
2      2694       6.2182         0.1949           411.1587       0.3707              391.1888         0.44186798                  0.7465077                 0.2894                True
3      3592       6.161          0.1953           373.5336       0.3177              390.0112         0.4454455                   0.76701236                0.2678                True
4      4490       6.1695         0.1964           307.6154       0.2309              388.776          0.44791612                  0.7800946                 0.2423                True
5      5388       6.1677         0.1968           332.8315       0.224               389.8837         0.4490773                   0.78811306                0.2493                False
6      6286       6.2306         0.2099           322.2471       0.1584              387.48           0.44984755                  0.7793525                 0.3057                False
7      7184       6.0064         0.1968           346.826        0.3733              386.7539         0.45105067                  0.7895414                 0.2312                True
8      8082       5.9378         0.2005           336.1071       0.2456              386.9609         0.45094076                  0.79110396                0.2436                False
9      8980       5.9679         0.218            365.19         0.3085              386.4074         0.45112094                  0.7961907                 0.2457                False
10     9878       6.0746         0.201            381.6792       0.2964              387.5968         0.45169595                  0.80174756                0.2299                False
11     10776      5.9657         0.209            399.4953       0.3386              386.7601         0.4526868                   0.79971445                0.2585                False
12     11674      5.967          0.232            315.4037       0.2407              387.1695         0.45198223                  0.8061229                 0.2258                True
13     12572      5.9638         0.2013           338.3375       0.2065              385.9676         0.45231536                  0.8021382                 0.2279                True
14     13470      5.9533         0.2008           345.3364       0.2405              385.7491         0.45258027                  0.80415                   0.2123                True
15     14368      5.9497         0.2008           312.7818       0.1463              386.3974         0.45249468                  0.79889864                0.2428                False
16     15266      5.9485         0.2009           309.6071       0.2257              386.8736         0.45289117                  0.8106955                 0.2307                False
17     16164      6.0055         0.1993           412.3003       0.1695              385.9764         0.45309043                  0.8117003                 0.2029                True
18     17062      5.9617         0.1963           384.7162       0.2055              386.1438         0.45205492                  0.80845743                0.263                 False
19     17960      5.956          0.1965           356.8955       0.1953              385.979          0.45251903                  0.8099449                 0.2131                False
20     18858      5.957          0.1962           374.5444       0.2541              386.2297         0.45311397                  0.8115832                 0.2191                False
21     19756      5.9579         0.1961           303.9682       0.2139              385.4861         0.45325646                  0.81220305                0.2099                False
22     20654      5.957          0.1965           364.9231       0.241               385.7539         0.4534836                   0.81158155                0.2145                False

Running this function prints a log to the screen containing some important training statistics and saves the same log to disk under <name>.log, where <name> is the name specified when you create the BPNet object.
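
The log also shows why training ended after epoch 22 rather than running for all 50 epochs. The sketch below replays the "Saved?" column through a simple early-stopping counter. It assumes that "Saved?" marks epochs where the validation loss improved and that the counter resets on every save; the helper `epochs_until_stop` is hypothetical and is not part of bpnet-lite.

```python
def epochs_until_stop(improved_flags, early_stopping):
    """Return the epoch at which training would stop, or None if it never does."""
    since_improvement = 0
    for epoch, improved in enumerate(improved_flags):
        # Reset the counter on an improvement, otherwise count one more bad epoch.
        since_improvement = 0 if improved else since_improvement + 1
        if since_improvement >= early_stopping:
            return epoch
    return None

# "Saved?" column from the training log, epochs 0 through 22.
saved = (
    [True] * 5 + [False] * 2 + [True] + [False] * 4
    + [True] * 3 + [False] * 2 + [True] + [False] * 5
)

stop_epoch = epochs_until_stop(saved, early_stopping=5)
print(stop_epoch)
```

The run of four unsaved epochs (8 through 11) is not enough to trigger the stop, whereas the run of five after epoch 17 is.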

Comparison to Official BPNet

We can now compare this model to the official TensorFlow model to check that the two perform comparably. First, we will load the model from the tarball uploaded to the ENCODE Portal for this experiment. See the tutorial on loading models for more details on what we are doing here.

[8]:
import tarfile

from io import BytesIO
from bpnetlite.bpnet import BasePairNet

with tarfile.open("ctcf-bpnet/ENCFF328YVP.tar.gz", "r:gz") as tar:
    model_tar = tar.extractfile("./fold_0/model.fold_0.ENCSR000BIE.h5").read()

official_bpnet = BasePairNet.from_bpnet(
    BytesIO(model_tar)
)

official_bpnet
[8]:
BasePairNet(
  (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
    (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
    (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
    (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
    (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
    (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
  )
  (rrelus): ModuleList(
    (0-7): 8 x ReLU()
  )
  (fconv): Conv1d(66, 2, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=65, out_features=1, bias=True)
)
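
The 2,114 bp input and 1,000 bp output seen in the validation shapes are consistent with the layer sizes printed above. The sketch below adds up how many positions a stack of unpadded convolutions with these kernel sizes and dilations would consume. This is only an arithmetic observation: the printed layers do use padding, so the model presumably trims its output rather than relying on unpadded convolutions.

```python
# (kernel_size, dilation) pairs read off the printed architecture.
layers = [(21, 1)]                             # iconv
layers += [(3, 2 ** i) for i in range(1, 9)]   # rconvs, dilations 2 through 256
layers += [(75, 1)]                            # fconv

# An unpadded convolution shortens its input by dilation * (kernel_size - 1).
context = sum(dilation * (kernel - 1) for kernel, dilation in layers)

in_window, out_window = 2114, 1000
print(context, in_window - out_window)
```

Both numbers come out to 1,114, that is, 557 positions of extra sequence context on each side of the 1,000 bp output window.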

Now, we can make predictions for both models on the validation data, making sure to pass in the control tracks for both.

[9]:
from tangermeme.predict import predict

new_y_logits, new_y_logcounts = predict(model, X_valid, args=(X_ctl_valid,))
official_y_logits, official_y_logcounts = predict(official_bpnet, X_valid, args=(X_ctl_valid,))

Then, we can use the built-in performance measures to fairly compare the performance of the two models.

[10]:
from bpnetlite.performance import calculate_performance_measures

lite_perfs = calculate_performance_measures(new_y_logits, y_valid, new_y_logcounts)
off_perfs = calculate_performance_measures(official_y_logits, y_valid, official_y_logcounts)

We can start off by looking at the log count Pearson.

[11]:
lite_perfs['count_pearson'], off_perfs['count_pearson']
[11]:
(tensor([0.8116]), tensor([0.8123]))

Looks like the two are of similar performance.
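
For readers unfamiliar with the metric, the following sketch computes a count Pearson correlation by hand on invented numbers. It assumes that predicted log counts are compared against log(1 + observed reads) across examples; the toy values are not from this dataset, and `pearson` is a hypothetical helper rather than the bpnet-lite implementation.

```python
import math

def pearson(x, y):
    """Pearson correlation between two equal-length lists of numbers."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)

# Toy data: total observed reads per example and predicted log counts.
true_counts = [12, 250, 40, 900, 3]
pred_log_counts = [2.9, 5.1, 3.5, 6.6, 1.8]

# Assumption: observed counts are compared on a log(1 + x) scale.
true_log_counts = [math.log(c + 1) for c in true_counts]

r = pearson(true_log_counts, pred_log_counts)
print(r)
```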

Next, we can look at the profile Pearson.

[12]:
import numpy

numpy.nan_to_num(lite_perfs['profile_pearson']).mean(), numpy.nan_to_num(off_perfs['profile_pearson']).mean()
[12]:
(0.45348364, 0.4504725)

Again, it looks like they are of similar performance.
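
The call to numpy.nan_to_num above matters because a Pearson correlation is undefined when one of the two profiles is constant, presumably for instance at a locus with no reads. The toy sketch below, with invented profiles and a hypothetical helper `pearson_or_nan`, shows how replacing NaN with 0 affects the mean compared with skipping undefined examples.

```python
import math

def pearson_or_nan(x, y):
    """Pearson correlation that returns NaN when either input is constant."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    if var_x == 0 or var_y == 0:
        return math.nan
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    return cov / math.sqrt(var_x * var_y)

# Three toy examples; the second has no observed reads at all.
observed = [[0, 1, 4, 1, 0], [0, 0, 0, 0, 0], [2, 2, 3, 9, 2]]
predicted = [[0.1, 0.2, 0.5, 0.1, 0.1],
             [0.2, 0.2, 0.2, 0.2, 0.2],
             [0.1, 0.1, 0.2, 0.5, 0.1]]

per_example = [pearson_or_nan(o, p) for o, p in zip(observed, predicted)]

# Counting undefined examples as 0, as nan_to_num does.
mean_with_zero = sum(0.0 if math.isnan(r) else r for r in per_example) / len(per_example)

# Alternative: average only over the examples where the value is defined.
valid = [r for r in per_example if not math.isnan(r)]
mean_skipping = sum(valid) / len(valid)

print(per_example, mean_with_zero, mean_skipping)
```

Counting undefined examples as zero lowers the average, but it does so identically for both models, so the comparison remains fair.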

As a more visual comparison, we can look at the predictions from the two models at a locus of interest.

[13]:
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn; seaborn.set_style('whitegrid')

plt.figure(figsize=(8, 5))
plt.subplot(311)
plt.title("Mapped Experimental Reads")
plt.plot(y_valid[1793].T)
plt.ylabel("Read Count")
seaborn.despine(bottom=True, left=True)

plt.subplot(312)
plt.title("bpnet-lite trained BPNet model")
plt.plot(torch.softmax(new_y_logits[1793], dim=-1).T)
plt.ylabel("Predicted\nProbability")
seaborn.despine(bottom=True, left=True)

plt.subplot(313)
plt.title("Official BPNet Model")
plt.ylabel("Predicted\nProbability")
plt.plot(torch.softmax(official_y_logits[1793], dim=-1).T)
seaborn.despine(bottom=True, left=True)

plt.tight_layout()
plt.show()
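[Figure: mapped experimental reads (top), predicted profile from the bpnet-lite trained BPNet model (middle), and predicted profile from the official BPNet model (bottom) at validation example 1793]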

It seems like the predictions are pretty similar for both of them, with both models getting the positioning of the primary peak as well as the much weaker peak to the left.
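
Note that the two prediction panels show probabilities rather than read counts, because the logits are passed through a softmax before plotting. The following is a minimal pure-Python sketch of a softmax over positions, using made-up logits, to show what that transformation does; the plotting code itself uses torch.softmax.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy logits along five positions with a peak in the middle.
logits = [-1.0, 0.5, 3.0, 0.5, -1.0]
probs = softmax(logits)

# Adding a constant to every logit leaves the probabilities unchanged,
# so the profile describes the shape of the signal, not its magnitude.
shifted = softmax([v + 100.0 for v in logits])

print(probs)
print(shifted)
```

This is why the overall signal strength at a locus is carried by the separate count prediction rather than by the profile.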

Finally, we can look at the attributions for the models at this locus.

[14]:
from tangermeme.deep_lift_shap import deep_lift_shap

from bpnetlite.bpnet import ControlWrapper
from bpnetlite.bpnet import CountWrapper

X_attr0 = deep_lift_shap(CountWrapper(ControlWrapper(model)), X_valid[1793:1794])
X_attr1 = deep_lift_shap(CountWrapper(ControlWrapper(official_bpnet)), X_valid[1793:1794])

[15]:
from tangermeme.plot import plot_logo

plt.figure(figsize=(10, 3))
plt.subplot(211)
plt.title("DeepLIFT/SHAP of bpnet-lite trained model")
plt.ylabel("Attribution")
plot_logo(X_attr0[0, :, 1000:1100])
plt.grid(False)

plt.subplot(212)
plt.title("DeepLIFT/SHAP of official model")
plt.ylabel("Attribution")
plot_logo(X_attr1[0, :, 1000:1100])
plt.grid(False)

plt.tight_layout()
plt.show()
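[Figure: DeepLIFT/SHAP attribution logos for the bpnet-lite trained model (top) and the official model (bottom), positions 1000 to 1100 of validation example 1793]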

The attributions look quite similar, perhaps even more similar than the predictions: both models highlight the two G's to the left of the motif, flanked by characters with negative attributions on either side.

Altogether, it looks like models trained with the bpnet-lite Python API perform comparably to the official models and produce quite similar predictions and attributions.