forked from msminhas93/DeepLabv3FineTuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
21 lines (17 loc) · 736 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
""" DeepLabv3 Model download and change the head for your prediction"""
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision import models
def createDeepLabv3(outputchannels=1):
    """Build a pretrained DeepLabv3 network with a freshly created head.

    The torchvision ``deeplabv3_resnet101`` model is instantiated with
    ``pretrained=True`` (showing download progress), and its ``classifier``
    attribute is then overwritten with a new ``DeepLabHead`` that emits
    ``outputchannels`` channels.

    Args:
        outputchannels (int, optional): Number of channels the replacement
            head should output, i.e. the number of channels in your
            dataset masks. Defaults to 1.

    Returns:
        The DeepLabv3 (ResNet101 backbone) model with the new head, already
        switched to training mode.
    """
    # Passed as the first argument of DeepLabHead (its input channel count).
    head_in_channels = 2048

    net = models.segmentation.deeplabv3_resnet101(
        pretrained=True,
        progress=True,
    )

    # Swap the stock classifier for one sized to the requested output.
    replacement_head = DeepLabHead(head_in_channels, outputchannels)
    net.classifier = replacement_head

    # Hand the model back in training mode.
    net.train()
    return net