forked from mahmoodlab/UNI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcell_patches.py
110 lines (94 loc) · 3.83 KB
/
cell_patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import logging
import numpy as np
import pandas as pd
import tifffile as tiff
import torch
from PIL import Image
from uni import get_encoder
# Lower the root logger threshold to INFO so that the logging.info() progress
# messages emitted throughout this script are not filtered out.
logging.getLogger().setLevel(logging.INFO)
class Uni:
    """Embed cell-centred image patches with the UNI encoder.

    On construction the image, the cell table and the model are loaded,
    cells whose patch would cross the image border are dropped, and the
    output CSV is created (or inspected, to resume a previous run).
    Call :meth:`run` to compute the embeddings batch by batch and append
    them to ``out_path``.

    Args:
        img_path: Path of the TIFF image to crop patches from.
        csv_path: Path of a CSV with ``CellID``, ``X_centroid`` and
            ``Y_centroid`` columns.
        out_path: Path of the CSV the embeddings are appended to.
        batch_size: Number of cells embedded per forward pass.
    """

    def __init__(self, img_path, csv_path, out_path, batch_size=50):
        self.img_path = img_path
        self.csv_path = csv_path
        self.out_path = out_path
        self.batch_size = batch_size
        # Half the side length of the 224-pixel patch cropped around a cell.
        self.offset = 224 // 2
        self.load_image()
        self.load_labels()
        self.filter_labels()
        self.load_model()
        self.max_idx = len(self.csv)
        self.start_idx = self.init_csv()

    def load_image(self):
        """Load image into memory."""
        logging.info('Loading image.')
        # Move the first axis of the array read from disk to the end
        # (presumably channels-first -> height, width, channels).
        self.img = tiff.imread(self.img_path).transpose(1, 2, 0)

    def load_labels(self):
        """Load cell information into ``self.csv`` with columns id, x, y."""
        cols = {'CellID': 'id', 'X_centroid': 'x', 'Y_centroid': 'y'}
        table = pd.read_csv(self.csv_path, usecols=list(cols))
        # read_csv keeps the column order of the file, not of ``usecols``;
        # reorder explicitly so the layout never depends on the input file.
        self.csv = table.rename(columns=cols)[list(cols.values())]

    def filter_labels(self):
        """Filter cells near image boundaries.

        Only cells whose centroid lies strictly more than ``self.offset``
        pixels away from every image edge are kept; the index is reset.
        """
        logging.info('Filtering labels.')
        img_height, img_width, _ = self.img.shape
        self.csv = self.csv[
            (self.csv.x > self.offset) &
            (self.csv.x < img_width - self.offset) &
            (self.csv.y > self.offset) &
            (self.csv.y < img_height - self.offset)
        ].reset_index(drop=True)

    def load_model(self):
        """Load UNI and its preprocessing transform onto the best device."""
        logging.info('Loading model.')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model, self.transform = get_encoder(enc_name='uni', device=self.device)

    def init_csv(self):
        """Initialize output file.

        Returns:
            The number of cells already present in ``out_path`` (used as
            the index to resume from), or 0 after creating a new file that
            contains only the header row.
        """
        if os.path.isfile(self.out_path):
            logging.info('Resuming processing from last saved cell.')
            return len(pd.read_csv(self.out_path, usecols=['cell_id']))
        logging.info('Creating new file.')
        cols = ['cell_id'] + [f'uni{i}' for i in range(1, 1025)]
        pd.DataFrame(columns=cols).to_csv(self.out_path, index=False)
        return 0

    def run(self):
        """Main function: embed every remaining cell, one batch at a time."""
        for idx in range(self.start_idx, self.max_idx, self.batch_size):
            logging.info(f'Processing cell {idx} of {self.max_idx}.')
            self.process_batch(idx)
        logging.info('Finished processing.')

    def process_batch(self, idx):
        """Embed the batch starting at row ``idx`` and append it to the output.

        Args:
            idx: Position in ``self.csv`` of the first cell of the batch.
        """
        indices = np.arange(idx, min(idx + self.batch_size, self.max_idx))
        batch = self.csv.iloc[indices]
        # Select coordinates by name rather than by position, and use a
        # signed dtype so subtracting the offset can never wrap around.
        loc = batch[['x', 'y']].to_numpy(dtype=np.int64)
        x = self.get_patches(loc)
        emb = self.embed_patches(x)
        df = self._embeddings_frame(batch['id'], emb.cpu().numpy())
        df.to_csv(self.out_path, mode='a', header=False, index=False)

    @staticmethod
    def _embeddings_frame(ids, emb):
        """Build the output rows: cell id first, then one column per feature.

        Args:
            ids: Cell identifiers, one per row of ``emb`` (Series or array).
            emb: 2-D array of embeddings.

        Returns:
            A DataFrame whose first column is ``cell_id``.
        """
        df = pd.DataFrame(emb)
        # Insert raw values, not the Series: a Series would be aligned on
        # its index (the row positions in the cell table), which matches
        # the fresh 0..n-1 index of ``df`` only for the very first batch
        # and would turn every later batch's ids into NaN.
        df.insert(0, 'cell_id', np.asarray(ids))
        return df

    def get_patches(self, loc):
        """Generate the stacked, transformed patches for ``loc``.

        Args:
            loc: Iterable of ``(xcenter, ycenter)`` integer pairs.

        Returns:
            A tensor stacking one transformed patch per location.
        """
        x = [self.crop_cell(xcenter, ycenter) for xcenter, ycenter in loc]
        x = [Image.fromarray(img) for img in x]
        x = [self.transform(img) for img in x]
        return torch.stack(x)

    def crop_cell(self, xcenter, ycenter):
        """Generate one patch of side ``2 * self.offset`` centred on a cell."""
        xstart, xend = xcenter - self.offset, xcenter + self.offset
        ystart, yend = ycenter - self.offset, ycenter + self.offset
        # Rows are indexed by y and columns by x.
        return self.img[ystart:yend, xstart:xend]

    def embed_patches(self, x):
        """Transform patches into 1024-dimensional embeddings."""
        x = x.to(self.device)
        # No gradients are needed when only extracting features.
        with torch.inference_mode():
            return self.model(x)
if __name__ == '__main__':
    # Input image, cell table and output file for this run.
    run_paths = {
        'img_path': '/n/scratch/users/d/daf179/melanoma/LSP26239_postHE_reg.ome.tif',
        'csv_path': '/n/scratch/users/d/daf179/melanoma/ML/mitosis_balanced.csv',
        'out_path': 'embeddings.csv',
    }
    # Construction loads the image, labels and model; run() then embeds
    # all cells that are not yet present in the output file.
    uni = Uni(**run_paths)
    uni.run()