Skip to content

Commit

Permalink
get dtype by global policy
Browse files Browse the repository at this point in the history
  • Loading branch information
fsx950223 committed Nov 20, 2020
1 parent 902e157 commit 4c83fbe
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
5 changes: 2 additions & 3 deletions efficientdet/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,8 @@ def dataset_parser(self, value, example_decoder, anchor_labeler, params):
classes = pad_to_fixed_size(classes, -1,
[self._max_instances_per_image, 1])
if params['mixed_precision']:
precision = utils.get_precision(params['strategy'],
params['mixed_precision'])
dtype = precision.split('_')[-1]
dtype = (
tf.keras.mixed_precision.experimental.global_policy().compute_dtype)
image = tf.cast(image, dtype=dtype)
box_targets = tf.nest.map_structure(
lambda box_target: tf.cast(box_target, dtype=dtype), box_targets)
Expand Down
4 changes: 1 addition & 3 deletions efficientdet/keras/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,7 @@ class and box losses from all levels.
"""
# Sum all positives in a batch for normalization and avoid zero
# num_positives_sum, which would lead to inf loss during training
precision = utils.get_precision(self.config.strategy,
self.config.mixed_precision)
dtype = precision.split('_')[-1]
dtype = tf.keras.mixed_precision.experimental.global_policy().compute_dtype
num_positives_sum = tf.reduce_sum(labels['mean_num_positives']) + 1.0
positives_momentum = self.config.positives_momentum or 0
if positives_momentum > 0:
Expand Down

0 comments on commit 4c83fbe

Please sign in to comment.