Commit b81d59d

add the automatic settings of block size for various lengths of prediction (#38)

* add the option of block size in triangular multiplication

* modify the branch of tri_mul_residual

* update the automatic setting of chunk/block size for various situations
BaozCWJ authored Sep 1, 2022
1 parent cbd7b8a commit b81d59d
Showing 7 changed files with 74 additions and 37 deletions.
1 change: 1 addition & 0 deletions unifold/config.py
@@ -227,6 +227,7 @@ def base_config():
         },
         "globals": {
             "chunk_size": chunk_size,
+            "block_size": None,
             "d_pair": d_pair,
             "d_msa": d_msa,
             "d_template": d_template,
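Note: the new block_size entry sits alongside chunk_size in globals and defaults to None, which leaves block-wise triangular multiplication disabled. A minimal, hypothetical usage sketch (not part of the diff; it assumes the unifold package is importable, and mirrors the attribute-style config access already used in inference.py below):

# Hypothetical sketch: enable block-wise triangular multiplication by hand.
# inference.py (below) normally picks this value automatically.
from unifold.config import base_config  # assumed import path, per the file above

config = base_config()
config.globals.block_size = 256  # an integer enables blocking; None disables it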
44 changes: 32 additions & 12 deletions unifold/inference.py
@@ -1,6 +1,7 @@
 import argparse
 import gzip
 import logging
+import math
 import numpy as np
 import os

@@ -16,20 +17,34 @@
     tensor_tree_map,
 )
 
-def automatic_chunk_size(seq_len):
-    if seq_len < 512:
+def get_device_mem(device):
+    if device != "cpu" and torch.cuda.is_available():
+        cur_device = torch.cuda.current_device()
+        prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device))
+        total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024
+        return total_memory_in_GB
+    else:
+        return 40

+def automatic_chunk_size(seq_len, device, is_bf16):
+    total_mem_in_GB = get_device_mem(device)
+    factor = math.sqrt(total_mem_in_GB/40.0*(0.55 * is_bf16 + 0.45))*0.95
+    if seq_len < int(1024*factor):
         chunk_size = 256
-    elif seq_len < 1024:
+        block_size = None
+    elif seq_len < int(2048*factor):
         chunk_size = 128
-    elif seq_len < 2048:
+        block_size = None
+    elif seq_len < int(3072*factor):
         chunk_size = 64
-    elif seq_len < 3072:
-        chunk_size = 16
+        block_size = None
+    elif seq_len < int(4096*factor):
+        chunk_size = 32
+        block_size = 512
     else:
-        chunk_size = 1
-    return chunk_size
+        chunk_size = 4
+        block_size = 256
+    return chunk_size, block_size
 
 def load_feature_for_one_target(
     config, data_folder, seed=0, is_multimer=False, use_uniprot=False
@@ -68,8 +83,6 @@ def main(args):
     if args.sample_templates:
         # enable template samples for diversity
         config.data.predict.subsample_templates = True
-        # faster prediction with large chunk
-        config.globals.chunk_size = 128
     model = AlphaFold(config)
 
     print("start to load params {}".format(args.param_path))
@@ -110,7 +123,14 @@ def main(args):
             use_uniprot=args.use_uniprot,
         )
         seq_len = batch["aatype"].shape[-1]
-        model.globals.chunk_size = automatic_chunk_size(seq_len)
+        # faster prediction with large chunk/block size
+        chunk_size, block_size = automatic_chunk_size(
+            seq_len,
+            args.model_device,
+            args.bf16
+        )
+        model.globals.chunk_size = chunk_size
+        model.globals.block_size = block_size
 
         with torch.no_grad():
             batch = {
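Note: to make the heuristic concrete, here is a small standalone sketch (not repository code; the helper name pick_sizes is made up) that reproduces the thresholds above for an assumed 40 GB device in bf16:

import math

def pick_sizes(seq_len, total_mem_in_GB=40.0, is_bf16=True):
    # same scaling rule as automatic_chunk_size: thresholds shrink on smaller
    # GPUs and in fp32 (is_bf16=False), since both raise memory use per residue
    factor = math.sqrt(total_mem_in_GB / 40.0 * (0.55 * is_bf16 + 0.45)) * 0.95
    if seq_len < int(1024 * factor):
        return 256, None
    elif seq_len < int(2048 * factor):
        return 128, None
    elif seq_len < int(3072 * factor):
        return 64, None
    elif seq_len < int(4096 * factor):
        return 32, 512
    else:
        return 4, 256

for n in (400, 1500, 3000, 5000):
    print(n, pick_sizes(n))  # (256, None), (128, None), (32, 512), (4, 256)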
3 changes: 3 additions & 0 deletions unifold/modules/alphafold.py
@@ -172,6 +172,7 @@ def embed_templates_pair_core(self, batch, z, pair_mask, tri_start_attn_mask, tr
             tri_end_attn_mask=tri_end_attn_mask,
             templ_dim=templ_dim,
             chunk_size=self.globals.chunk_size,
+            block_size=self.globals.block_size,
             return_mean=not self.enable_template_pointwise_attention,
         )
         return t
@@ -318,6 +319,7 @@ def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
             z,
             msa_mask=feats["extra_msa_mask"],
             chunk_size=self.globals.chunk_size,
+            block_size=self.globals.block_size,
             pair_mask=pair_mask,
             msa_row_attn_mask=extra_msa_row_mask,
             msa_col_attn_mask=None,
@@ -345,6 +347,7 @@ def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
             tri_start_attn_mask=tri_start_attn_mask,
             tri_end_attn_mask=tri_end_attn_mask,
             chunk_size=self.globals.chunk_size,
+            block_size=self.globals.block_size,
         )
         return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb

9 changes: 7 additions & 2 deletions unifold/modules/common.py
@@ -262,9 +262,9 @@ def tri_mul_residual(
     dropout_shared_dim,
     prob,
     training,
-    chunk_size=None,
+    block_size,
 ):
-    if training or chunk_size is None:
+    if training:
         x, g = outputs
         bias, g_bias = module.get_output_bias()
         shape = list(x.shape)
@@ -280,6 +280,11 @@
             mask,
             prob,
         )
+    elif block_size is None:
+        x, g = outputs
+        bias, g_bias = module.get_output_bias()
+        residual += (torch.sigmoid(g + g_bias) * (x + bias))
+        return residual
     else:
        # gated is not used here
        residual += outputs
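Note: the new elif block_size is None branch applies the gate and output bias explicitly at inference time, i.e. residual + sigmoid(g + g_bias) * (x + bias); in the blocked path the module has already applied the gate internally, so only a plain residual add remains. A toy standalone sketch of that update (assumed shapes, not the repository's module API):

import torch

def gated_residual(residual, x, g, bias, g_bias):
    # x: ungated output, g: gate logits (both from the triangle-multiplication
    # module); bias / g_bias: the module's output biases, added lazily here
    return residual + torch.sigmoid(g + g_bias) * (x + bias)

d = 8
residual = torch.zeros(2, 16, 16, d)  # pair representation [*, N, N, C]
x, g = torch.randn(2, 16, 16, d), torch.randn(2, 16, 16, d)
bias, g_bias = torch.zeros(d), torch.zeros(d)
print(gated_residual(residual, x, g, bias, g_bias).shape)  # torch.Size([2, 16, 16, 8])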
13 changes: 9 additions & 4 deletions unifold/modules/evoformer.py
@@ -124,6 +124,7 @@ def forward(
         tri_start_attn_mask: torch.Tensor,
         tri_end_attn_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        block_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
         if self.outer_product_mean_first:
@@ -169,21 +170,21 @@ def forward(
         z = tri_mul_residual(
             self.tri_mul_out,
             z,
-            self.tri_mul_out(z, mask=pair_mask, chunk_size=chunk_size),
+            self.tri_mul_out(z, mask=pair_mask, block_size=block_size),
             self.row_dropout_share_dim,
             self.pair_dropout,
             self.training,
-            chunk_size=chunk_size,
+            block_size=block_size,
         )
 
         z = tri_mul_residual(
             self.tri_mul_in,
             z,
-            self.tri_mul_in(z, mask=pair_mask, chunk_size=chunk_size),
+            self.tri_mul_in(z, mask=pair_mask, block_size=block_size),
             self.row_dropout_share_dim,
             self.pair_dropout,
             self.training,
-            chunk_size=chunk_size,
+            block_size=block_size,
         )
 
         z = bias_dropout_residual(
@@ -274,6 +275,7 @@ def forward(
         tri_start_attn_mask: torch.Tensor,
         tri_end_attn_mask: torch.Tensor,
         chunk_size: int,
+        block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         blocks = [
             partial(
@@ -285,6 +287,7 @@
                 tri_start_attn_mask=tri_start_attn_mask,
                 tri_end_attn_mask=tri_end_attn_mask,
                 chunk_size=chunk_size,
+                block_size=block_size
             )
             for b in self.blocks
         ]
@@ -355,6 +358,7 @@ def forward(
         tri_start_attn_mask: torch.Tensor = None,
         tri_end_attn_mask: torch.Tensor = None,
         chunk_size: int = None,
+        block_size: int = None,
     ) -> torch.Tensor:
         _, z, _ = super().forward(
             m,
@@ -366,5 +370,6 @@
             tri_start_attn_mask=tri_start_attn_mask,
             tri_end_attn_mask=tri_end_attn_mask,
             chunk_size=chunk_size,
+            block_size=block_size
         )
         return z
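Note: the stack-level forward methods only thread block_size through to each block via functools.partial, exactly as chunk_size already was. A simplified standalone sketch of that pattern (toy block functions, not the real EvoformerIteration API):

from functools import partial

def run_blocks(blocks, m, z, chunk_size=None, block_size=None):
    # bind the sizes once, then run the wrapped blocks sequentially over (m, z)
    wrapped = [partial(b, chunk_size=chunk_size, block_size=block_size) for b in blocks]
    for block in wrapped:
        m, z = block(m, z)
    return m, z

def toy_block(m, z, chunk_size=None, block_size=None):
    return m + 1, z + 1  # stand-in for one Evoformer iteration

print(run_blocks([toy_block, toy_block], 0, 0, chunk_size=64, block_size=256))  # (2, 2)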
21 changes: 12 additions & 9 deletions unifold/modules/template.py
@@ -161,6 +161,7 @@ def forward(
         tri_start_attn_mask: torch.Tensor,
         tri_end_attn_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
+        block_size: Optional[int] = None,
     ):
         if self.tri_attn_first:
             s = bias_dropout_residual(
@@ -185,41 +186,41 @@
             s = tri_mul_residual(
                 self.tri_mul_out,
                 s,
-                self.tri_mul_out(s, mask=mask, chunk_size=chunk_size),
+                self.tri_mul_out(s, mask=mask, block_size=block_size),
                 self.row_dropout_share_dim,
                 self.dropout,
                 self.training,
-                chunk_size=chunk_size
+                block_size=block_size,
             )
 
             s = tri_mul_residual(
                 self.tri_mul_in,
                 s,
-                self.tri_mul_in(s, mask=mask, chunk_size=chunk_size),
+                self.tri_mul_in(s, mask=mask, block_size=block_size),
                 self.row_dropout_share_dim,
                 self.dropout,
                 self.training,
-                chunk_size=chunk_size
+                block_size=block_size,
             )
         else:
             s = tri_mul_residual(
                 self.tri_mul_out,
                 s,
-                self.tri_mul_out(s, mask=mask, chunk_size=chunk_size),
+                self.tri_mul_out(s, mask=mask, block_size=block_size),
                 self.row_dropout_share_dim,
                 self.dropout,
                 self.training,
-                chunk_size=chunk_size
+                block_size=block_size,
            )
 
            s = tri_mul_residual(
                 self.tri_mul_in,
                 s,
-                self.tri_mul_in(s, mask=mask, chunk_size=chunk_size),
+                self.tri_mul_in(s, mask=mask, block_size=block_size),
                 self.row_dropout_share_dim,
                 self.dropout,
                 self.training,
-                chunk_size=chunk_size
+                block_size=block_size,
            )
 
        s = bias_dropout_residual(
@@ -293,6 +294,7 @@ def forward(
         tri_end_attn_mask: torch.Tensor,
         templ_dim: int,
         chunk_size: int,
+        block_size: int,
         return_mean: bool,
     ):
         def one_template(i):
@@ -304,13 +306,14 @@ def one_template(i):
                         tri_start_attn_mask=tri_start_attn_mask,
                         tri_end_attn_mask=tri_end_attn_mask,
                         chunk_size=chunk_size,
+                        block_size=block_size,
                     )
                     for b in self.blocks
                 ],
                 input=(single_templates[i],),
             )
             return s
 
         n_templ = len(single_templates)
         if n_templ > 0:
             new_single_templates = [one_template(0)]
20 changes: 10 additions & 10 deletions unifold/modules/triangle_multiplication.py
@@ -31,11 +31,11 @@ def _chunk_2d(
         self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: int = None,
+        block_size: int = None,
     ) -> torch.Tensor:
 
         # avoid too small chunk size
-        chunk_size = max(chunk_size, 256)
+        # block_size = max(block_size, 256)
         new_z = z.new_zeros(z.shape)
         dim1 = z.shape[-3]
 
@@ -53,10 +53,10 @@ def _chunk_projection(z, mask, a=True):
             p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
             return p
 
-        num_chunk = (dim1 + chunk_size - 1) // chunk_size
+        num_chunk = (dim1 + block_size - 1) // block_size
         for i in range(num_chunk):
-            chunk_start = i * chunk_size
-            chunk_end = min(chunk_start + chunk_size, dim1)
+            chunk_start = i * block_size
+            chunk_end = min(chunk_start + block_size, dim1)
             if self.outgoing:
                 a_chunk = _chunk_projection(
                     z[..., chunk_start:chunk_end, :, :],
@@ -73,8 +73,8 @@ def _chunk_projection(z, mask, a=True):
                 a_chunk = a_chunk.transpose(-1, -3)
 
             for j in range(num_chunk):
-                j_chunk_start = j * chunk_size
-                j_chunk_end = min(j_chunk_start + chunk_size, dim1)
+                j_chunk_start = j * block_size
+                j_chunk_end = min(j_chunk_start + block_size, dim1)
                 if self.outgoing:
                     b_chunk = _chunk_projection(
                         z[..., j_chunk_start:j_chunk_end, :, :],
@@ -110,7 +110,7 @@ def forward(
         self,
         z: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size=None,
+        block_size=None,
     ) -> torch.Tensor:
 
         mask = mask.unsqueeze(-1)
@@ -119,8 +119,8 @@ def forward(
         mask = mask * (mask.shape[-2] ** -0.5)
 
         z = self.layer_norm_in(z)
-        if not self.training and chunk_size is not None:
-            return self._chunk_2d(z, mask, chunk_size=chunk_size)
+        if not self.training and block_size is not None:
+            return self._chunk_2d(z, mask, block_size=block_size)
 
         g = nn.functional.linear(z, self.linear_g.weight)
         if self.training:
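Note: the renamed block_size argument controls how _chunk_2d tiles the pairwise update. A toy standalone sketch of the idea for the "outgoing" case (simplified: the real code also applies masks, gates and a layer norm through _chunk_projection):

import torch

def blocked_triangle_mul_outgoing(a, b, block_size):
    # full result: out[i, j, :] = sum_k a[i, k, :] * b[j, k, :]; computed one
    # (i-block, j-block) tile at a time so only two row blocks are live at once
    n = a.shape[0]
    out = a.new_zeros(a.shape)
    for i0 in range(0, n, block_size):
        a_blk = a[i0:i0 + block_size]              # [bi, N, C]
        for j0 in range(0, n, block_size):
            b_blk = b[j0:j0 + block_size]          # [bj, N, C]
            out[i0:i0 + block_size, j0:j0 + block_size] = torch.einsum(
                "ikc,jkc->ijc", a_blk, b_blk
            )
    return out

n, c = 6, 4
a, b = torch.randn(n, n, c), torch.randn(n, n, c)
reference = torch.einsum("ikc,jkc->ijc", a, b)
print(torch.allclose(blocked_triangle_mul_outgoing(a, b, block_size=2), reference, atol=1e-5))  # True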
