Skip to content

Commit

Permalink
Optimize Net._get_next_net_name (pytorch#107479)
Browse files Browse the repository at this point in the history
Summary: This is surprisingly expensive and can be easily optimized.

Differential Revision: D48440000

Pull Request resolved: pytorch#107479
Approved by: https://github.com/kit1980
  • Loading branch information
jeffdunn authored and pytorchmergebot committed Aug 22, 2023
1 parent 24147a8 commit 1e9b590
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions caffe2/python/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from itertools import chain
from typing import Dict

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
Expand Down Expand Up @@ -1445,7 +1446,7 @@ def _recover_record_by_prefix(names, prefix=''):


class Net:
_net_names_used = set()
_net_names_used_counters: Dict[str, int] = {}
operator_registry_ = {}

@staticmethod
Expand All @@ -1454,17 +1455,16 @@ def current_prefix():
builder = NetBuilder.current(required=False)
return builder.name if builder else ''

@staticmethod
def _reset_used_names() -> None:
Net._net_names_used_counters = {}

@staticmethod
def _get_next_net_name(basename):
name = basename = '/'.join(
x for x in [Net.current_prefix(), basename] if x
)
next_idx = 1
while name in Net._net_names_used:
name = basename + '_' + str(next_idx)
next_idx += 1
Net._net_names_used |= set([name])
return name
basename = "/".join(x for x in [Net.current_prefix(), basename] if x)
next_idx = Net._net_names_used_counters.get(basename, 0)
Net._net_names_used_counters[basename] = next_idx + 1
return basename if next_idx == 0 else f"{basename}_{next_idx}"

def __init__(self, name_or_proto, inplace=False):
"""
Expand Down

0 comments on commit 1e9b590

Please sign in to comment.