FIX: use inceptionv4 from timm (#69)
* use inceptionv4 from timm

This changes the model name from inceptionv4 to inception_v4 and allows
us to remove the inceptionv4 implementation from the codebase.

* use inception_v4nobn with the help of timm

Use the timm implementation but remove batchnorm (and add a bias term to
the conv layers that previously fed into batchnorm); see the sketch below
these notes.

This changes the name from inceptionv4nobn to inception_v4nobn.

* rm unused default_cfgs

* ignore timm types
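
For illustration only (not part of this commit): a minimal sketch of the conv-block change described above. With batchnorm, the conv bias is redundant because batchnorm applies its own learned shift; once batchnorm is removed, the conv layer needs bias=True. The class names here are hypothetical.

import torch.nn as nn

class ConvBn(nn.Module):
    # Conv -> BatchNorm -> ReLU: bias=False because BatchNorm adds its own shift.
    def __init__(self, in_ch, out_ch, **kw):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, bias=False, **kw)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class ConvNoBn(nn.Module):
    # Conv -> ReLU: BatchNorm removed, so the conv carries bias=True instead.
    def __init__(self, in_ch, out_ch, **kw):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, bias=True, **kw)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x))

# Example: block = ConvNoBn(3, 32, kernel_size=3, stride=2)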
kaczmarj authored Jan 15, 2023
1 parent 1cc6f6a commit 1683bba
Showing 7 changed files with 127 additions and 387 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -89,7 +89,7 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-pandas]
 ignore_missing_imports = True
-[mypy-timm]
+[mypy-timm.*]
 ignore_missing_imports = True
 [mypy-scipy.stats]
 ignore_missing_imports = True
6 changes: 3 additions & 3 deletions tests/test_all.py
@@ -216,9 +216,9 @@ def test_cli_run_args(tmp_path: Path):
            350,
            144,
        ),
-        # Inceptionv4 TCGA-BRCA-v1
+        # Inception_v4 TCGA-BRCA-v1
        (
-            "inceptionv4",
+            "inception_v4",
            "TCGA-BRCA-v1",
            ["notumor", "tumor"],
            [0.9564113020896912, 0.043588679283857346],
@@ -227,7 +227,7 @@ def test_cli_run_args(tmp_path: Path):
        ),
        # Inceptionv4nobn TCGA-TILs-v1
        (
-            "inceptionv4nobn",
+            "inception_v4nobn",
            "TCGA-TILs-v1",
            ["notils", "tils"],
            [1.0, 3.427359524660334e-12],
306 changes: 0 additions & 306 deletions wsinfer/_modellib/inceptionv4.py

This file was deleted.

190 changes: 119 additions & 71 deletions wsinfer/_modellib/inceptionv4_no_batchnorm.py
@@ -1,21 +1,22 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017, Remi Cadene
-# All rights reserved.
-#
-# Downloaded from
-# https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/e07fb68c317880e780eb5ca9c20cca00f2584878/pretrainedmodels/models/inceptionv4.py # noqa: E501
-#
-# We downloaded this file here so we did not have to add pretrainedmodels as a
-# dependency (we only use this module).
-#
-# Modified to not use batchnorm. Models trained with TF Slim do not use batchnorm.
+# https://raw.githubusercontent.com/rwightman/pytorch-image-models/e9aac412de82310e6905992e802b1ee4dc52b5d1/timm/models/inception_v4.py
+"""
+Pytorch Inception-V4 implementation
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
+based upon Google's Tensorflow implementation and pretrained weights
+(Apache 2.0 License).
+This source was copied into the wsinfer source code and modified to remove batchnorm.
+Bias terms are added wherever batchnorm is removed.
+"""

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-__all__ = ["InceptionV4", "inceptionv4"]
+# from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.models import register_model
+from timm.models.helpers import build_model_with_cfg
+from timm.models.layers import create_classifier


 class BasicConv2d(nn.Module):
@@ -27,19 +28,21 @@ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            bias=True,  # Changed this to True after removing batchnorm.
+            bias=True,  # Set to True after removing BatchNorm.
         )
+        # self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
         self.relu = nn.ReLU(inplace=True)

     def forward(self, x):
         x = self.conv(x)
+        # x = self.bn(x)
         x = self.relu(x)
         return x


-class Mixed_3a(nn.Module):
+class Mixed3a(nn.Module):
     def __init__(self):
-        super(Mixed_3a, self).__init__()
+        super(Mixed3a, self).__init__()
         self.maxpool = nn.MaxPool2d(3, stride=2)
         self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
@@ -50,9 +53,9 @@ def forward(self, x):
         return out


-class Mixed_4a(nn.Module):
+class Mixed4a(nn.Module):
     def __init__(self):
-        super(Mixed_4a, self).__init__()
+        super(Mixed4a, self).__init__()

         self.branch0 = nn.Sequential(
             BasicConv2d(160, 64, kernel_size=1, stride=1),
@@ -73,9 +76,9 @@ def forward(self, x):
         return out


-class Mixed_5a(nn.Module):
+class Mixed5a(nn.Module):
     def __init__(self):
-        super(Mixed_5a, self).__init__()
+        super(Mixed5a, self).__init__()
         self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
         self.maxpool = nn.MaxPool2d(3, stride=2)
@@ -86,9 +89,9 @@ def forward(self, x):
         return out


-class Inception_A(nn.Module):
+class InceptionA(nn.Module):
     def __init__(self):
-        super(Inception_A, self).__init__()
+        super(InceptionA, self).__init__()
         self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

         self.branch1 = nn.Sequential(
@@ -116,9 +119,9 @@ def forward(self, x):
         return out


-class Reduction_A(nn.Module):
+class ReductionA(nn.Module):
     def __init__(self):
-        super(Reduction_A, self).__init__()
+        super(ReductionA, self).__init__()
         self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

         self.branch1 = nn.Sequential(
@@ -137,9 +140,9 @@ def forward(self, x):
         return out


-class Inception_B(nn.Module):
+class InceptionB(nn.Module):
     def __init__(self):
-        super(Inception_B, self).__init__()
+        super(InceptionB, self).__init__()
         self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

         self.branch1 = nn.Sequential(
@@ -170,9 +173,9 @@ def forward(self, x):
         return out


-class Reduction_B(nn.Module):
+class ReductionB(nn.Module):
     def __init__(self):
-        super(Reduction_B, self).__init__()
+        super(ReductionB, self).__init__()

         self.branch0 = nn.Sequential(
             BasicConv2d(1024, 192, kernel_size=1, stride=1),
@@ -196,9 +199,9 @@ def forward(self, x):
         return out


-class Inception_C(nn.Module):
+class InceptionC(nn.Module):
     def __init__(self):
-        super(Inception_C, self).__init__()
+        super(InceptionC, self).__init__()

         self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

@@ -251,53 +254,98 @@ def forward(self, x):


 class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1001):
+    def __init__(
+        self,
+        num_classes=1000,
+        in_chans=3,
+        output_stride=32,
+        drop_rate=0.0,
+        global_pool="avg",
+    ):
         super(InceptionV4, self).__init__()
-        # Special attributs
-        self.input_space = None
-        self.input_size = (299, 299, 3)
-        self.mean = None
-        self.std = None
-        # Modules
+        assert output_stride == 32
+        self.drop_rate = drop_rate
+        self.num_classes = num_classes
+        self.num_features = 1536

         self.features = nn.Sequential(
-            BasicConv2d(3, 32, kernel_size=3, stride=2),
+            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
             BasicConv2d(32, 32, kernel_size=3, stride=1),
             BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
-            Mixed_3a(),
-            Mixed_4a(),
-            Mixed_5a(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Reduction_A(),  # Mixed_6a
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Reduction_B(),  # Mixed_7a
-            Inception_C(),
-            Inception_C(),
-            Inception_C(),
+            Mixed3a(),
+            Mixed4a(),
+            Mixed5a(),
+            InceptionA(),
+            InceptionA(),
+            InceptionA(),
+            InceptionA(),
+            ReductionA(),  # Mixed6a
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            ReductionB(),  # Mixed7a
+            InceptionC(),
+            InceptionC(),
+            InceptionC(),
         )
-        self.last_linear = nn.Linear(1536, num_classes)
-
-    def logits(self, features):
-        # Allows image of any size to be processed
-        adaptiveAvgPoolWidth = features.shape[2]
-        x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
-        x = x.view(x.size(0), -1)
-        x = self.last_linear(x)
-        return x
+        self.feature_info = [
+            dict(num_chs=64, reduction=2, module="features.2"),
+            dict(num_chs=160, reduction=4, module="features.3"),
+            dict(num_chs=384, reduction=8, module="features.9"),
+            dict(num_chs=1024, reduction=16, module="features.17"),
+            dict(num_chs=1536, reduction=32, module="features.21"),
+        ]
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool
+        )

+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(stem=r"^features\.[012]\.", blocks=r"^features\.(\d+)")
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, "gradient checkpointing not supported"
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.last_linear

-    def forward(self, input):
-        x = self.features(input)
-        x = self.logits(x)
+    def reset_classifier(self, num_classes, global_pool="avg"):
+        self.num_classes = num_classes
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool
+        )
+
+    def forward_features(self, x):
+        return self.features(x)
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        if self.drop_rate > 0:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        return x if pre_logits else self.last_linear(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
         return x


-def inceptionv4(num_classes=1000):
-    return InceptionV4(num_classes=num_classes)
+def _create_inception_v4(variant, pretrained=False, **kwargs):
+    return build_model_with_cfg(
+        InceptionV4,
+        variant,
+        pretrained,
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs
+    )
+
+
+@register_model
+def inception_v4nobn(pretrained=False, **kwargs):
+    return _create_inception_v4("inception_v4nobn", pretrained, **kwargs)
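
Not shown in the diff: once this module is imported, the @register_model decorator has run, so the model can be created by name through timm. A minimal usage sketch, assuming timm resolves the registered name with pretrained=False:

import timm

# Importing the module has the side effect of registering inception_v4nobn.
from wsinfer._modellib import inceptionv4_no_batchnorm  # noqa: F401

# num_classes is forwarded through build_model_with_cfg to InceptionV4.
model = timm.create_model("inception_v4nobn", pretrained=False, num_classes=2)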
6 changes: 2 additions & 4 deletions wsinfer/_modellib/models.py
@@ -15,8 +15,8 @@
 from torch.hub import load_state_dict_from_url
 import yaml

-from .inceptionv4 import inceptionv4 as _inceptionv4
-from .inceptionv4_no_batchnorm import inceptionv4 as _inceptionv4_no_bn
+# Imported for side effects of registering model.
+from . import inceptionv4_no_batchnorm as _  # noqa
 from .resnet_preact import resnet34_preact as _resnet34_preact
 from .vgg16mod import vgg16mod as _vgg16mod
 from .transforms import PatchClassification
@@ -233,8 +233,6 @@ def get_sha256_of_weights(self) -> str:

 # Container for all models we can use that are not in timm.
 _model_registry: Dict[str, Callable[[int], torch.nn.Module]] = {
-    "inceptionv4": _inceptionv4,
-    "inceptionv4nobn": _inceptionv4_no_bn,
     "preactresnet34": _resnet34_preact,
     "vgg16mod": _vgg16mod,
 }
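
A hypothetical sketch (not from this commit) of how an architecture name can now be resolved: the local registry handles models that are not in timm, and everything else — including the newly registered inception_v4nobn and timm's own inception_v4 — goes through timm.create_model. The helper name is illustrative; wsinfer's actual lookup may differ.

import timm
import torch

def create_model_by_name(name: str, num_classes: int) -> torch.nn.Module:
    # Hypothetical helper: local registry first, then fall back to timm.
    if name in _model_registry:  # preactresnet34, vgg16mod, ...
        return _model_registry[name](num_classes)
    return timm.create_model(name, num_classes=num_classes)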
2 changes: 1 addition & 1 deletion wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml
@@ -3,7 +3,7 @@
 version: "1.0"
 # The models are referenced by the pair of [architecture, weights], so this pair must
 # be unique.
-architecture: inceptionv4 # Must be a string.
+architecture: inception_v4 # Must be a string.
 name: TCGA-BRCA-v1 # Must be a string.
 # Where to get the model weights. Either a URL or path to a file.
 # If using a URL, set the url_file_name (the name of the file when it is downloaded).
2 changes: 1 addition & 1 deletion wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml
@@ -4,7 +4,7 @@ version: "1.0"
 # The models are referenced by the pair of [architecture, weights], so this pair must
 # be unique.
 # Inceptionv4 without batch normalization.
-architecture: inceptionv4nobn # Must be a string.
+architecture: inception_v4nobn # Must be a string.
 name: TCGA-TILs-v1 # Must be a string.
 # Where to get the model weights. Either a URL or path to a file.
 # If using a URL, set the url_file_name (the name of the file when it is downloaded).
