diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..dbc5734
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,8 @@
+[settings]
+profile = black
+multi_line_output = 3
+include_trailing_comma = True
+force_grid_wrap = 0
+use_parentheses = True
+ensure_newline_before_comments = True
+line_length = 119
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a0d4138..a138ce1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,12 @@
 exclude: _pb2\.py$
 repos:
+- repo: https://github.com/pre-commit/mirrors-isort
+  rev: f0001b2  # Use the revision sha / tag you want to point at
+  hooks:
+  - id: isort
+    args: ["--profile", "black"]
 - repo: https://github.com/psf/black
-  rev: 19.10b0
+  rev: 20.8b1
   hooks:
   - id: black
 - repo: https://github.com/asottile/yesqa
@@ -26,11 +31,19 @@ repos:
   - id: trailing-whitespace
   - id: flake8
   - id: requirements-txt-fixer
+- repo: https://github.com/pre-commit/mirrors-pylint
+  rev: d230ffd
+  hooks:
+  - id: pylint
+    args:
+    - --max-line-length=119
+    - --ignore-imports=yes
+    - -d duplicate-code
 - repo: https://github.com/asottile/pyupgrade
-  rev: v1.13.0
+  rev: v2.7.3
   hooks:
   - id: pyupgrade
-    args: ['--py36-plus']
+    args: ['--py37-plus']
 - repo: https://github.com/pre-commit/pygrep-hooks
   rev: v1.5.1
   hooks:
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 0000000..259d22b
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,598 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loaded into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-whitelist=
+
+# Specify a score threshold to be exceeded before program exits with error.
+fail-under=10.0
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use.
+jobs=0
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=print-statement,
+        parameter-unpacking,
+        unpacking-in-except,
+        old-raise-syntax,
+        backtick,
+        long-suffix,
+        old-ne-operator,
+        old-octal-literal,
+        import-star-module-level,
+        non-ascii-bytes-literal,
+        raw-checker-failed,
+        bad-inline-option,
+        locally-disabled,
+        file-ignored,
+        suppressed-message,
+        useless-suppression,
+        deprecated-pragma,
+        use-symbolic-message-instead,
+        apply-builtin,
+        basestring-builtin,
+        buffer-builtin,
+        cmp-builtin,
+        coerce-builtin,
+        execfile-builtin,
+        file-builtin,
+        long-builtin,
+        raw_input-builtin,
+        reduce-builtin,
+        standarderror-builtin,
+        unicode-builtin,
+        xrange-builtin,
+        coerce-method,
+        delslice-method,
+        getslice-method,
+        setslice-method,
+        no-absolute-import,
+        old-division,
+        dict-iter-method,
+        dict-view-method,
+        next-method-called,
+        metaclass-assignment,
+        indexing-exception,
+        raising-string,
+        reload-builtin,
+        oct-method,
+        hex-method,
+        nonzero-method,
+        cmp-method,
+        input-builtin,
+        round-builtin,
+        intern-builtin,
+        unichr-builtin,
+        map-builtin-not-iterating,
+        zip-builtin-not-iterating,
+        range-builtin-not-iterating,
+        filter-builtin-not-iterating,
+        using-cmp-argument,
+        eq-without-hash,
+        div-method,
+        idiv-method,
+        rdiv-method,
+        exception-message-attribute,
+        invalid-str-codec,
+        sys-max-int,
+        bad-python3-import,
+        deprecated-string-function,
+        deprecated-str-translate-call,
+        deprecated-itertools-function,
+        deprecated-types-field,
+        next-method-defined,
+        dict-items-not-iterating,
+        dict-keys-not-iterating,
+        dict-values-not-iterating,
+        deprecated-operator-function,
+        deprecated-urllib-function,
+        xreadlines-attribute,
+        deprecated-sys-function,
+        exception-escape,
+        comprehension-escape,
+        missing-module-docstring,
+        invalid-name,
+        missing-function-docstring,
+        redefined-outer-name,
+        import-error,
+        missing-class-docstring,
+        too-few-public-methods,
+        attribute-defined-outside-init,
+        too-many-locals,
+        too-many-arguments
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifiers separated by comma (,) or put this option
+# multiple times (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'error', 'warning', 'refactor', and 'convention'
+# which contain the number of messages in each category, as well as 'statement'
+# which is the total number of statements analyzed. This score is used by the
+# global evaluation report (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never return. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+# Maximum number of characters on a single line.
+max-line-length=119
+
+# Maximum number of lines in a module.
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+# Regular expression of note tags to take in consideration.
+#notes-rgx=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style.
+#class-attribute-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style.
+#variable-rgx=
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=optparse,tkinter.tix
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled).
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled).
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp,
+                      __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=cls
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "BaseException, Exception".
+overgeneral-exceptions=BaseException,
+                       Exception
diff --git a/README.md b/README.md
index 136d450..0775e00 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ It added a set of functionality:
 
 * Synchronized BatchNorm
 * Support for various loggers like [W&B](https://www.wandb.com/) or [Neptune.ml](https://neptune.ai/)
 
-### Hyperparameters are fedined in config file
+### Hyperparameters are defined in the config file
 
 Hyperparameters that were scattered across the code have been moved to the config at [retinaface/configs](retinaface/configs)
@@ -103,6 +103,20 @@ You can convert the default labels of the WiderFaces to the json of the proper format
 
 ## Training
 
+### Define config
+Example configs can be found at [retinaface/configs](retinaface/configs)
+
+### Define environment variables
+
+```bash
+export TRAIN_IMAGE_PATH=
+export VAL_IMAGE_PATH=
+export TRAIN_LABEL_PATH=
+export VAL_LABEL_PATH=
+```
+
+### Run training script
+
 ```
 python retinaface/train.py -h
 usage: train.py [-h] -c CONFIG_PATH
diff --git a/requirements.txt b/requirements.txt
index f3db32f..19a0380 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,4 @@ albumentations
 iglovikov_helper_functions
 numpy
 pillow
-streamlit
 torch
diff --git a/retinaface/box_utils.py b/retinaface/box_utils.py
index cd13e0b..6962516 100644
--- a/retinaface/box_utils.py
+++ b/retinaface/box_utils.py
@@ -5,7 +5,8 @@
 
 
 def point_form(boxes: torch.Tensor) -> torch.Tensor:
-    """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison to point form ground truth data.
+ """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison + to point form ground truth data. Args: boxes: center-size default boxes from priorbox layers. @@ -26,7 +27,7 @@ def center_size(boxes: torch.Tensor) -> torch.Tensor: def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: - """ We resize both tensors to [A,B,2] without new malloc: + """We resize both tensors to [A,B,2] without new malloc: [A, 2] -> [A, 1, 2] -> [A, B, 2] [B, 2] -> [1, B, 2] -> [A, B, 2] Then we compute the area of intersect between box_a and box_b. @@ -125,14 +126,14 @@ def match( best_prior_idx_filter.squeeze_(1) best_prior_overlap.squeeze_(1) best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior - # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes best_truth_idx[best_prior_idx[j]] = j - matches = box_gt[best_truth_idx] # Shape: [num_priors, 4] 此处为每一个anchor对应的bbox取出来 - labels = labels_gt[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 - labels[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + matches = box_gt[best_truth_idx] # Shape: [num_priors, 4] + labels = labels_gt[best_truth_idx] # Shape: [num_priors] + labels[best_truth_overlap < threshold] = 0 # label as background overlap<0.35 loc = encode(matches, priors, variances) matches_landm = landmarks_gt[best_truth_idx] @@ -209,7 +210,6 @@ def decode( Return: decoded bounding box predictions """ - boxes = torch.cat( ( priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], diff --git a/retinaface/configs/2020-07-20.yaml b/retinaface/configs/2020-07-20.yaml index b63ccd8..f66a2e3 100644 --- a/retinaface/configs/2020-07-20.yaml +++ b/retinaface/configs/2020-07-20.yaml @@ -33,7 +33,6 @@ optimizer: trainer: type: pytorch_lightning.Trainer - early_stop_callback: False gpus: 4 use_amp: True amp_level: O1 @@ -43,6 +42,7 @@ trainer: progress_bar_refresh_rate: 1 benchmark: True precision: 16 + sync_batchnorm: True scheduler: type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts diff --git a/retinaface/configs/2020-11-15.yaml b/retinaface/configs/2020-11-15.yaml new file mode 100644 index 0000000..0503407 --- /dev/null +++ b/retinaface/configs/2020-11-15.yaml @@ -0,0 +1,163 @@ +--- +seed: 1984 + +num_workers: 4 +experiment_name: "2020-11-15" + +num_classes: 2 + +model: + type: retinaface.network.RetinaFace + name: Resnet50 + pretrained: True + return_layers: {"layer2": 1, "layer3": 2, "layer4": 3} + in_channels: 256 + out_channels: 256 + +optimizer: + type: torch.optim.SGD + lr: 0.001 + weight_decay: 0.0001 + momentum: 0.9 + +trainer: + type: pytorch_lightning.Trainer + gpus: 4 + amp_level: O1 + max_epochs: 150 + distributed_backend: ddp + num_sanity_val_steps: 1 + progress_bar_refresh_rate: 1 + benchmark: True + precision: 16 + sync_batchnorm: True + +scheduler: + type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + T_0: 10 + T_mult: 2 + +train_parameters: + batch_size: 6 + rotate90: True + +checkpoint_callback: + type: pytorch_lightning.callbacks.ModelCheckpoint + filepath: "2020-11-15" + monitor: val_loss + verbose: True + mode: max + save_top_k: -1 + +val_parameters: + batch_size: 10 + iou_threshold: 0.4 + rotate90: True + box_min_size: 5 + +loss: + type: retinaface.multibox_loss.MultiBoxLoss + num_classes: 2 + overlap_thresh: 0.35 + prior_for_matching: True + bkg_label: 0 + neg_mining: True + neg_pos: 7 + 
+  neg_overlap: 0.35
+  encode_target: False
+
+prior_box:
+  type: retinaface.prior_box.priorbox
+  min_sizes: [[16, 32], [64, 128], [256, 512]]
+  steps: [8, 16, 32]
+  clip: False
+
+image_size: [1024, 1024]
+
+loss_weights:
+  localization: 2
+  classification: 1
+  landmarks: 1
+
+test_parameters:
+  variance: [0.1, 0.2]
+
+train_aug:
+  transform:
+    __class_fullname__: albumentations.core.composition.Compose
+    bbox_params: null
+    keypoint_params: null
+    p: 1
+    transforms:
+      - __class_fullname__: albumentations.augmentations.transforms.RandomBrightnessContrast
+        always_apply: false
+        brightness_limit: 0.2
+        contrast_limit: [0.5, 1.5]
+        p: 0.5
+      - __class_fullname__: albumentations.augmentations.transforms.HueSaturationValue
+        hue_shift_limit: 20
+        val_shift_limit: 20
+        p: 0.5
+      - __class_fullname__: albumentations.augmentations.transforms.RandomGamma
+        gamma_limit: [80, 120]
+        p: 0.5
+      - __class_fullname__: albumentations.augmentations.transforms.Resize
+        height: 1024
+        width: 1024
+        p: 1
+      - __class_fullname__: albumentations.augmentations.transforms.Normalize
+        always_apply: false
+        max_pixel_value: 255.0
+        mean:
+          - 0.485
+          - 0.456
+          - 0.406
+        p: 1
+        std:
+          - 0.229
+          - 0.224
+          - 0.225
+
+val_aug:
+  transform:
+    __class_fullname__: albumentations.core.composition.Compose
+    bbox_params: null
+    keypoint_params: null
+    p: 1
+    transforms:
+      - __class_fullname__: albumentations.augmentations.transforms.Resize
+        height: 1024
+        width: 1024
+        p: 1
+      - __class_fullname__: albumentations.augmentations.transforms.Normalize
+        always_apply: false
+        max_pixel_value: 255.0
+        mean:
+          - 0.485
+          - 0.456
+          - 0.406
+        p: 1
+        std:
+          - 0.229
+          - 0.224
+          - 0.225
+
+test_aug:
+  transform:
+    __class_fullname__: albumentations.core.composition.Compose
+    bbox_params: null
+    keypoint_params: null
+    p: 1
+    transforms:
+      - __class_fullname__: albumentations.augmentations.transforms.Normalize
+        always_apply: false
+        max_pixel_value: 255.0
+        mean:
+          - 0.485
+          - 0.456
+          - 0.406
+        p: 1
+        std:
+          - 0.229
+          - 0.224
+          - 0.225
diff --git a/retinaface/data_augment.py b/retinaface/data_augment.py
index d6b4c0c..0d2b4dc 100644
--- a/retinaface/data_augment.py
+++ b/retinaface/data_augment.py
@@ -6,19 +6,20 @@
 
 from retinaface.box_utils import matrix_iof
 
 
-def _crop(
+def random_crop(
     image: np.ndarray, boxes: np.ndarray, labels: np.ndarray, landm: np.ndarray, img_dim: int
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, bool]:
+    """
+    if random.uniform(0, 1) <= 0.2:
+        scale = 1.0
+    else:
+        scale = random.uniform(0.3, 1.0)
+    """
     height, width = image.shape[:2]
     pad_image_flag = True
 
     for _ in range(250):
-        """
-        if random.uniform(0, 1) <= 0.2:
-            scale = 1.0
-        else:
-            scale = random.uniform(0.3, 1.0)
-        """
+
         pre_scales = [0.3, 0.45, 0.6, 0.8, 1.0]
         scale = random.choice(pre_scales)
         short_side = min(width, height)
@@ -80,7 +81,9 @@ def _crop(
     return image, boxes, labels, landm, pad_image_flag
 
 
-def _mirror(image: np.ndarray, boxes: np.ndarray, landms: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+def random_horizontal_flip(
+    image: np.ndarray, boxes: np.ndarray, landms: np.ndarray
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     width = image.shape[1]
     if random.randrange(2):
         image = image[:, ::-1]
@@ -124,10 +127,12 @@ def __call__(self, image: np.ndarray, targets: np.ndarray) -> Tuple[np.ndarray,
         landmarks = targets[:, 4:-1].copy()
         labels = targets[:, -1:].copy()
 
-        image_t, boxes_t, labels_t, landmarks_t, pad_image_flag = _crop(image, boxes, labels, landmarks, self.img_dim)
+        image_t, boxes_t, labels_t, landmarks_t, pad_image_flag = random_crop(
+            image, boxes, labels, landmarks, self.img_dim
+        )
         image_t = _pad_to_square(image_t, pad_image_flag)
 
-        image_t, boxes_t, landmarks_t = _mirror(image_t, boxes_t, landmarks_t)
+        image_t, boxes_t, landmarks_t = random_horizontal_flip(image_t, boxes_t, landmarks_t)
 
         height, width = image_t.shape[:2]
         boxes_t[:, 0::2] = boxes_t[:, 0::2] / width
diff --git a/retinaface/dataset.py b/retinaface/dataset.py
index 587fbaa..8c60b03 100644
--- a/retinaface/dataset.py
+++ b/retinaface/dataset.py
@@ -1,12 +1,12 @@
 import json
 from pathlib import Path
-from typing import Dict, Any, List, Tuple, Optional
+from typing import Any, Dict, List, Tuple
 
 import albumentations as albu
 import numpy as np
 import torch
+from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
 from iglovikov_helper_functions.utils.image_utils import load_rgb
-from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image
 from torch.utils import data
 
 from retinaface.data_augment import Preproc
@@ -15,26 +15,23 @@ class FaceDetectionDataset(data.Dataset):
     def __init__(
         self,
-        label_path: str,
-        image_path: Optional[str],
+        label_path: Path,
+        image_path: Path,
         transform: albu.Compose,
         preproc: Preproc,
         rotate90: bool = False,
     ) -> None:
         self.preproc = preproc
 
-        if image_path is None:
-            self.image_path = image_path
-        else:
-            self.image_path = Path(image_path)
+        self.image_path = Path(image_path)
 
         self.transform = transform
         self.rotate90 = rotate90
 
         with open(label_path) as f:
-            self.labels = json.load(f)
+            labels = json.load(f)
 
-        self.labels = [x for x in self.labels if Path(x["file_path"]).exists()]
+        self.labels = [x for x in labels if (image_path / x["file_name"]).exists()]
 
     def __len__(self) -> int:
         return len(self.labels)
@@ -44,26 +41,26 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
 
         file_name = labels["file_name"]
 
-        if self.image_path is None:
-            image = load_rgb(labels["file_path"])
-        else:
-            image = load_rgb(self.image_path / file_name)
+        image = load_rgb(self.image_path / file_name)
+
+        image_height, image_width = image.shape[:2]
 
         # annotations will have the format
         # 4: box, 10 landmarks, 1: landmarks / no landmarks
         num_annotations = 4 + 10 + 1
         annotations = np.zeros((0, num_annotations))
 
-        image_height, image_width = image.shape[:2]
-
         for label in labels["annotations"]:
             annotation = np.zeros((1, num_annotations))
+
             x_min, y_min, x_max, y_max = label["bbox"]
 
-            annotation[0, 0] = np.clip(x_min, 0, image_width - 1)
-            annotation[0, 1] = np.clip(y_min, 0, image_height - 1)
-            annotation[0, 2] = np.clip(x_max, x_min + 1, image_width - 1)
-            annotation[0, 3] = np.clip(y_max, y_min + 1, image_height - 1)
+            x_min = np.clip(x_min, 0, image_width - 1)
+            y_min = np.clip(y_min, 0, image_height - 1)
+            x_max = np.clip(x_max, x_min + 1, image_width - 1)
+            y_max = np.clip(y_max, y_min + 1, image_height - 1)
+
+            annotation[0, :4] = x_min, y_min, x_max, y_max
 
         if "landmarks" in label and label["landmarks"]:
             landmarks = np.array(label["landmarks"])
diff --git a/retinaface/inference.py b/retinaface/inference.py
index 5c0d061..716f57d 100644
--- a/retinaface/inference.py
+++ b/retinaface/inference.py
@@ -1,7 +1,7 @@
 import argparse
 import json
 from pathlib import Path
-from typing import Dict, List, Union, Optional, Any
+from typing import Any, Dict, List, Optional, Union
 
 import albumentations as albu
 import cv2
@@ -11,10 +11,14 @@
 import torch.utils.data
 import torch.utils.data.distributed
 import yaml
-from PIL import Image, UnidentifiedImageError
 from albumentations.core.serialization import from_dict
 from iglovikov_helper_functions.config_parsing.utils import object_from_dict
+from iglovikov_helper_functions.dl.pytorch.utils import (
+    state_dict_from_disk,
+    tensor_from_rgb_image,
+)
 from iglovikov_helper_functions.utils.image_utils import pad_to_size, unpad_from_size
+from PIL import Image
 from torch.nn import functional as F
 from torch.utils.data import Dataset
 from torch.utils.data.distributed import DistributedSampler
@@ -22,8 +26,7 @@
 from tqdm import tqdm
 
 from retinaface.box_utils import decode, decode_landm
-from retinaface.utils import load_checkpoint, vis_annotations
-from retinaface.utils import tensor_from_rgb_image
+from retinaface.utils import vis_annotations
 
 
 def get_args():
@@ -180,17 +183,6 @@ def process_predictions(
     return result
 
 
-def check_if_image(file_list: List[Path]) -> List[Path]:
-    result: List[Path] = []
-    for file_path in tqdm(file_list):
-        try:
-            Image.open(file_path)
-        except UnidentifiedImageError:
-            continue
-        result += [file_path]
-    return result
-
-
 def main():
     args = get_args()
     torch.distributed.init_process_group(backend="nccl")
@@ -230,17 +222,14 @@
         model = model.half()
 
     corrections: Dict[str, str] = {"model.": ""}
-    checkpoint = load_checkpoint(file_path=args.weight_path, rename_in_layers=corrections)
-    model.load_state_dict(checkpoint["state_dict"])
+    state_dict = state_dict_from_disk(file_path=args.weight_path, rename_in_layers=corrections)
+    model.load_state_dict(state_dict)
 
     model = torch.nn.parallel.DistributedDataParallel(
         model, device_ids=[args.local_rank], output_device=args.local_rank
     )
 
-    file_paths = []
-
-    for regexp in ["*"]:
-        file_paths += check_if_image([x for x in args.input_path.rglob(regexp)])
+    file_paths = list(args.input_path.rglob("*.jpg"))
 
     dataset = InferenceDataset(file_paths, max_size=args.max_size, transform=from_dict(hparams["test_aug"]))
diff --git a/retinaface/net.py b/retinaface/net.py
index 7602ed0..98d8427 100644
--- a/retinaface/net.py
+++ b/retinaface/net.py
@@ -1,4 +1,4 @@
-from typing import List, Dict
+from typing import Dict, List
 
 import torch
 import torch.nn.functional as F
@@ -14,7 +14,10 @@ def conv_bn(inp: int, oup: int, stride: int = 1, leaky: float = 0) -> nn.Sequent
 
 
 def conv_bn_no_relu(inp: int, oup: int, stride: int) -> nn.Sequential:
-    return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),)
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+    )
 
 
 def conv_bn1X1(inp: int, oup: int, stride: int, leaky: float = 0) -> nn.Sequential:
diff --git a/retinaface/predict_single.py b/retinaface/predict_single.py
index 438f99a..c4986c0 100644
--- a/retinaface/predict_single.py
+++ b/retinaface/predict_single.py
@@ -6,6 +6,7 @@
 import albumentations as A
 import numpy as np
 import torch
+from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
 from iglovikov_helper_functions.utils.image_utils import pad_to_size, unpad_from_size
 from torch.nn import functional as F
 from torchvision.ops import nms
@@ -13,7 +14,6 @@
 from retinaface.box_utils import decode, decode_landm
 from retinaface.network import RetinaFace
 from retinaface.prior_box import priorbox
-from retinaface.utils import tensor_from_rgb_image
 
 
 class Model:
diff --git a/retinaface/train.py b/retinaface/train.py
index d22890c..a0d7dc7 100644
--- a/retinaface/train.py
+++ b/retinaface/train.py
@@ -1,25 +1,36 @@
 import argparse
+import os
 from collections import OrderedDict
 from pathlib import Path
-from typing import List, Dict, Any
+from typing import Any, Dict, List
-import apex
 import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 import yaml
+from addict import Dict as Adict
 from albumentations.core.serialization import from_dict
 from iglovikov_helper_functions.config_parsing.utils import object_from_dict
 from iglovikov_helper_functions.metrics.map import recall_precision
-from pytorch_lightning.logging import WandbLogger
+from pytorch_lightning.loggers import WandbLogger
 from torch.utils.data import DataLoader
 from torchvision.ops import nms
 
 from retinaface.box_utils import decode
 from retinaface.data_augment import Preproc
 from retinaface.dataset import FaceDetectionDataset, detection_collate
-from retinaface.utils import load_checkpoint
+
+TRAIN_IMAGE_PATH = Path(os.environ["TRAIN_IMAGE_PATH"])
+VAL_IMAGE_PATH = Path(os.environ["VAL_IMAGE_PATH"])
+
+TRAIN_LABEL_PATH = Path(os.environ["TRAIN_LABEL_PATH"])
+VAL_LABEL_PATH = Path(os.environ["VAL_LABEL_PATH"])
+
+print("TRAIN_IMAGE_PATH = ", TRAIN_IMAGE_PATH)
+print("VAL_IMAGE_PATH = ", VAL_IMAGE_PATH)
+print("TRAIN_LABEL_PATH = ", TRAIN_LABEL_PATH)
+print("VAL_LABEL_PATH = ", VAL_LABEL_PATH)
 
 
 def get_args():
@@ -30,76 +41,72 @@ def get_args():
 
 
 class RetinaFace(pl.LightningModule):
-    def __init__(self, hparams: Dict[str, Any]):
+    def __init__(self, config):
         super().__init__()
-        self.hparams = hparams
-
-        self.prior_box = object_from_dict(self.hparams["prior_box"], image_size=self.hparams["image_size"])
-        self.model = object_from_dict(self.hparams["model"])
-        corrections: Dict[str, str] = {"model.": ""}
-
-        if "weights" in self.hparams:
-            checkpoint = load_checkpoint(file_path=self.hparams["weights"], rename_in_layers=corrections)
-            self.model.load_state_dict(checkpoint["state_dict"])
+        self.config = config
 
-        if hparams["sync_bn"]:
-            self.model = apex.parallel.convert_syncbn_model(self.model)
+        self.prior_box = object_from_dict(self.config.prior_box, image_size=self.config.image_size)
+        self.model = object_from_dict(self.config.model)
 
-        self.loss_weights = self.hparams["loss_weights"]
+        self.loss_weights = self.config.loss_weights
 
-        self.loss = object_from_dict(self.hparams["loss"], priors=self.prior_box)
+        self.loss = object_from_dict(self.config.loss, priors=self.prior_box)
 
-    def setup(self, state: int = 0) -> None:
-        self.preproc = Preproc(img_dim=self.hparams["image_size"][0])
+    def setup(self, state=0):  # pylint: disable=W0613
+        self.preproc = Preproc(img_dim=self.config.image_size[0])
 
-    def forward(self, batch: torch.Tensor) -> torch.Tensor:
+    def forward(self, batch):
        return self.model(batch)
 
     def train_dataloader(self):
-        return DataLoader(
+        result = DataLoader(
             FaceDetectionDataset(
-                label_path=self.hparams["train_annotation_path"],
-                image_path=self.hparams["train_image_path"],
-                transform=from_dict(self.hparams["train_aug"]),
+                label_path=TRAIN_LABEL_PATH,
+                image_path=TRAIN_IMAGE_PATH,
+                transform=from_dict(self.config.train_aug),
                 preproc=self.preproc,
-                rotate90=self.hparams["train_parameters"]["rotate90"],
+                rotate90=self.config.train_parameters.rotate90,
             ),
-            batch_size=self.hparams["train_parameters"]["batch_size"],
-            num_workers=self.hparams["num_workers"],
+            batch_size=self.config.train_parameters.batch_size,
+            num_workers=self.config.num_workers,
             shuffle=True,
             pin_memory=True,
             drop_last=False,
             collate_fn=detection_collate,
         )
+        print("Len train dataloader = ", len(result))
+        return result
 
     def val_dataloader(self):
-        return DataLoader(
+        result = DataLoader(
             FaceDetectionDataset(
-                label_path=self.hparams["val_annotation_path"],
-                image_path=self.hparams["val_image_path"],
-                transform=from_dict(self.hparams["val_aug"]),
+                label_path=VAL_LABEL_PATH,
+                image_path=VAL_IMAGE_PATH,
+                transform=from_dict(self.config.val_aug),
                 preproc=self.preproc,
-                rotate90=self.hparams["train_parameters"]["rotate90"],
+                rotate90=self.config.val_parameters.rotate90,
             ),
-            batch_size=self.hparams["val_parameters"]["batch_size"],
-            num_workers=self.hparams["num_workers"],
+            batch_size=self.config.val_parameters.batch_size,
+            num_workers=self.config.num_workers,
             shuffle=False,
             pin_memory=True,
             drop_last=True,
             collate_fn=detection_collate,
         )
+        print("Len val dataloader = ", len(result))
+        return result
 
     def configure_optimizers(self):
         optimizer = object_from_dict(
-            self.hparams["optimizer"], params=[x for x in self.model.parameters() if x.requires_grad]
+            self.config.optimizer, params=[x for x in self.model.parameters() if x.requires_grad]
         )
 
-        scheduler = object_from_dict(self.hparams["scheduler"], optimizer=optimizer)
+        scheduler = object_from_dict(self.config.scheduler, optimizer=optimizer)
 
         self.optimizers = [optimizer]
         return self.optimizers, [scheduler]
 
-    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+    def training_step(self, batch, batch_idx):  # pylint: disable=W0613
         images = batch["image"]
 
         targets = batch["annotation"]
@@ -113,27 +120,15 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[
             + self.loss_weights["landmarks"] * loss_landmarks
         )
 
-        logs = {
-            "classification": loss_classification,
-            "localization": loss_localization,
-            "landmarks": loss_landmarks,
-            "train_loss": total_loss,
-            "lr": self._get_current_lr(),
-        }
-
-        return OrderedDict(
-            {
-                "loss": total_loss,
-                "progress_bar": {
-                    "train_loss": total_loss,
-                    "classification": loss_classification,
-                    "localization": loss_localization,
-                },
-                "log": logs,
-            }
-        )
+        self.log("train_classification", loss_classification, on_step=True, on_epoch=True, logger=True, prog_bar=True)
+        self.log("train_localization", loss_localization, on_step=True, on_epoch=True, logger=True, prog_bar=True)
+        self.log("train_landmarks", loss_landmarks, on_step=True, on_epoch=True, logger=True, prog_bar=True)
+        self.log("train_loss", total_loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
+        self.log("lr", self._get_current_lr(), on_step=True, on_epoch=True, logger=True, prog_bar=True)
+
+        return total_loss
 
-    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+    def validation_step(self, batch, batch_idx):  # pylint: disable=W0613
         images = batch["image"]
 
         image_height = images.shape[2]
@@ -155,7 +150,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dic
 
         for batch_id in range(batch_size):
             boxes = decode(
-                location.data[batch_id], self.prior_box.to(images.device), self.hparams["test_parameters"]["variance"]
+                location.data[batch_id], self.prior_box.to(images.device), self.config.test_parameters.variance
             )
 
             scores = confidence[batch_id][:, 1]
@@ -166,7 +161,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dic
             boxes *= scale
 
             # do NMS
-            keep = nms(boxes, scores, self.hparams["val_parameters"]["iou_threshold"])
+            keep = nms(boxes, scores, self.config.val_parameters.iou_threshold)
             boxes = boxes[keep, :].cpu().numpy()
 
             if boxes.shape[0] == 0:
@@ -215,7 +210,7 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dic
 
         return OrderedDict({"predictions": predictions_coco, "gt": gt_coco})
 
-    def validation_epoch_end(self, outputs: List) -> Dict[str, Any]:
+    def validation_epoch_end(self, outputs: List) -> None:
         result_predictions: List[dict] = []
         result_gt: List[dict] = []
 
@@ -225,11 +220,10 @@
 
         _, _, average_precision = recall_precision(result_gt, result_predictions, 0.5)
 
-        logs = {"epoch": self.trainer.current_epoch, "mAP@0.5": average_precision}
+        self.log("epoch", self.trainer.current_epoch, on_step=False, on_epoch=True, logger=True)
+        self.log("val_loss", average_precision, on_step=False, on_epoch=True, logger=True)
 
-        return {"val_loss": average_precision, "log": logs}
-
-    def _get_current_lr(self) -> torch.Tensor:
+    def _get_current_lr(self) -> torch.Tensor:  # type: ignore
         lr = [x["lr"] for x in self.optimizers[0].param_groups][0]
         return torch.from_numpy(np.array([lr]))[0].to(self.device)
 
@@ -238,16 +232,18 @@
 def main():
     args = get_args()
 
     with open(args.config_path) as f:
-        hparams = yaml.load(f, Loader=yaml.SafeLoader)
+        config = Adict(yaml.load(f, Loader=yaml.SafeLoader))
+
+    pl.trainer.seed_everything(config.seed)
 
-    pipeline = RetinaFace(hparams)
+    pipeline = RetinaFace(config)
 
-    Path(hparams["checkpoint_callback"]["filepath"]).mkdir(exist_ok=True, parents=True)
+    Path(config.checkpoint_callback.filepath).mkdir(exist_ok=True, parents=True)
 
     trainer = object_from_dict(
-        hparams["trainer"],
-        logger=WandbLogger(hparams["experiment_name"]),
-        checkpoint_callback=object_from_dict(hparams["checkpoint_callback"]),
+        config.trainer,
+        logger=WandbLogger(config.experiment_name),
+        checkpoint_callback=object_from_dict(config.checkpoint_callback),
     )
 
     trainer.fit(pipeline)
diff --git a/retinaface/utils.py b/retinaface/utils.py
index 028096e..6f0f4a8 100644
--- a/retinaface/utils.py
+++ b/retinaface/utils.py
@@ -1,41 +1,7 @@
-import re
-from pathlib import Path
-from typing import Union, Optional, Any, Dict, List
+from typing import Any, Dict, List
 
 import cv2
 import numpy as np
-import torch
-
-
-def load_checkpoint(file_path: Union[Path, str], rename_in_layers: Optional[dict] = None) -> Dict[str, Any]:
-    """Loads PyTorch checkpoint, optionally renaming layer names.
-    Args:
-        file_path: path to the torch checkpoint.
-        rename_in_layers: {from_name: to_name}
-            ex: {"model.0.": "",
-                 "model.": ""}
-    Returns:
-    """
-    checkpoint = torch.load(file_path, map_location=lambda storage, loc: storage)
-
-    if rename_in_layers is not None:
-        model_state_dict = checkpoint["state_dict"]
-
-        result = {}
-        for key, value in model_state_dict.items():
-            for key_r, value_r in rename_in_layers.items():
-                key = re.sub(key_r, value_r, key)
-
-            result[key] = value
-
-        checkpoint["state_dict"] = result
-
-    return checkpoint
-
-
-def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor:
-    image = np.transpose(image, (2, 0, 1))
-    return torch.from_numpy(image)
 
 
 def vis_annotations(image: np.ndarray, annotations: List[Dict[str, Any]]) -> np.ndarray: