Skip to content

Commit

Permalink
Merge pull request #96 from aertslab/render_plot
Browse files Browse the repository at this point in the history
Functional rework small fixes
  • Loading branch information
LukasMahieu authored Feb 13, 2025
2 parents 76c49a2 + a006d39 commit f457204
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 114 deletions.
6 changes: 2 additions & 4 deletions src/crested/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,14 @@ def render_plot(
label.set_fontsize(y_tick_fontsize)
label.set_rotation(y_label_rotation)
if tight_rect:
plt.tight_layout(rect=tight_rect)
fig.tight_layout(rect=tight_rect)
else:
plt.tight_layout()
fig.tight_layout()
if save_path:
plt.savefig(save_path)

if show:
plt.show()
else:
plt.close(fig)

if not show and not save_path:
return fig
2 changes: 2 additions & 0 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def load_model(self, model_path: os.PathLike, compile: bool = True) -> None:
the model weights (e.g. when finetuning a model). If False, you should
provide a TaskConfig to the Crested object before calling fit.
"""
if compile and self.config is not None:
logger.warning("Loading a model with compile=True. The CREsted config object will be ignored.")
self.model = keras.models.load_model(model_path, compile=compile)

def fit(
Expand Down
111 changes: 1 addition & 110 deletions src/crested/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,116 +14,7 @@

from crested._genome import Genome, _resolve_genome
from crested._io import _extract_tracks_from_bigwig


def get_hot_encoding_table(
alphabet: str = "ACGT",
neutral_alphabet: str = "N",
neutral_value: float = 0.0,
dtype=np.float32,
) -> np.ndarray:
"""Get hot encoding table to encode a DNA sequence to a numpy array with shape (len(sequence), len(alphabet)) using bytes."""

def str_to_uint8(string) -> np.ndarray:
"""Convert string to byte representation."""
return np.frombuffer(string.encode("ascii"), dtype=np.uint8)

# 256 x 4
hot_encoding_table = np.zeros(
(np.iinfo(np.uint8).max + 1, len(alphabet)), dtype=dtype
)

# For each ASCII value of the nucleotides used in the alphabet
# (upper and lower case), set 1 in the correct column.
hot_encoding_table[str_to_uint8(alphabet.upper())] = np.eye(
len(alphabet), dtype=dtype
)
hot_encoding_table[str_to_uint8(alphabet.lower())] = np.eye(
len(alphabet), dtype=dtype
)

# For each ASCII value of the nucleotides used in the neutral alphabet
# (upper and lower case), set neutral_value in the correct column.
hot_encoding_table[str_to_uint8(neutral_alphabet.upper())] = neutral_value
hot_encoding_table[str_to_uint8(neutral_alphabet.lower())] = neutral_value

return hot_encoding_table


HOT_ENCODING_TABLE = get_hot_encoding_table()


def one_hot_encode_sequence(sequence: str, expand_dim: bool = True) -> np.ndarray:
"""
One hot encode a DNA sequence.
Will return a numpy array with shape (1, len(sequence), 4) if expand_dim is True, otherwise (len(sequence),4).
Alphabet is ACGT.
Parameters
----------
sequence
The DNA sequence to one hot encode.
expand_dim
Whether to expand the dimensions of the output array.
Returns
-------
The one hot encoded DNA sequence.
"""
if expand_dim:
return np.expand_dims(
HOT_ENCODING_TABLE[np.frombuffer(sequence.encode("ascii"), dtype=np.uint8)],
axis=0,
)
else:
return HOT_ENCODING_TABLE[
np.frombuffer(sequence.encode("ascii"), dtype=np.uint8)
]


def generate_mutagenesis(x, include_original=True, flanks=(0, 0)):
"""Generate all possible single point mutations in a sequence."""
_, L, A = x.shape
start, end = 0, L
x_mut = []
start = flanks[0]
end = L - flanks[1]
for length in range(start, end):
for a in range(A):
if not include_original:
if x[0, length, a] == 1:
continue
x_new = np.copy(x)
x_new[0, length, :] = 0
x_new[0, length, a] = 1
x_mut.append(x_new)
return np.concatenate(x_mut, axis=0)


def generate_motif_insertions(x, motif, flanks=(0, 0), masked_locations=None):
"""Generate motif insertions in a sequence."""
_, L, A = x.shape
start, end = 0, L
x_mut = []
motif_length = motif.shape[1]
start = flanks[0]
end = L - flanks[1] - motif_length + 1
insertion_locations = []

for motif_start in range(start, end):
motif_end = motif_start + motif_length
if masked_locations is not None:
if np.any(
(motif_start <= masked_locations) & (masked_locations < motif_end)
):
continue
x_new = np.copy(x)
x_new[0, motif_start:motif_end, :] = motif
x_mut.append(x_new)
insertion_locations.append(motif_start)

return np.concatenate(x_mut, axis=0), insertion_locations
from crested.utils._seq_utils import one_hot_encode_sequence


class EnhancerOptimizer:
Expand Down

0 comments on commit f457204

Please sign in to comment.