diff --git a/environ/environ.py b/environ/environ.py index 67b7f7d8..0d944cc7 100644 --- a/environ/environ.py +++ b/environ/environ.py @@ -11,12 +11,15 @@ variables to configure your Django application. """ + import ast +import contextlib import itertools import logging import os import re import sys +from typing import Any, Optional, Union import warnings from urllib.parse import ( parse_qs, @@ -205,67 +208,55 @@ def __call__(self, var, cast=None, default=NOTSET, parse_default=False): def __contains__(self, var): return var in self.ENVIRON - def str(self, var, default=NOTSET, multiline=False): + def str(self, var, default: Union[NoValue, str] = NOTSET, multiline=False): """ :rtype: str """ value = self.get_value(var, cast=str, default=default) - if multiline: - return re.sub(r'(\\r)?\\n', r'\n', value) - return value + return re.sub(r'(\\r)?\\n', r'\n', value) if multiline else value def bytes(self, var, default=NOTSET, encoding='utf8'): """ :rtype: bytes """ value = self.get_value(var, cast=str, default=default) - if hasattr(value, 'encode'): - return value.encode(encoding) - return value + return value.encode(encoding) if hasattr(value, 'encode') else value - def bool(self, var, default=NOTSET): + def bool(self, var, default: Union[bool, NoValue] = NOTSET): """ :rtype: bool """ return self.get_value(var, cast=bool, default=default) - def int(self, var, default=NOTSET): + def int(self, var, default: Union[int, NoValue] =NOTSET): """ :rtype: int """ return self.get_value(var, cast=int, default=default) - def float(self, var, default=NOTSET): + def float(self, var, default: Union[float, NoValue] =NOTSET): """ :rtype: float """ return self.get_value(var, cast=float, default=default) - def json(self, var, default=NOTSET): + def json(self, var, default: Union[NoValue, str] =NOTSET): """ :returns: Json parsed """ return self.get_value(var, cast=json.loads, default=default) - def list(self, var, cast=None, default=NOTSET): + def list(self, var, cast=None, default: Union[list, NoValue] =NOTSET): """ :rtype: list """ - return self.get_value( - var, - cast=list if not cast else [cast], - default=default - ) + return self.get_value(var, cast=[cast] if cast else list, default=default) - def tuple(self, var, cast=None, default=NOTSET): + def tuple(self, var, cast=None, default: Union[tuple, NoValue] =NOTSET): """ :rtype: tuple """ - return self.get_value( - var, - cast=tuple if not cast else (cast,), - default=default - ) + return self.get_value(var, cast=(cast, ) if cast else tuple, default=default) def dict(self, var, cast=dict, default=NOTSET): """ @@ -273,7 +264,7 @@ def dict(self, var, cast=dict, default=NOTSET): """ return self.get_value(var, cast=cast, default=default) - def url(self, var, default=NOTSET): + def url(self, var, default: Union[str, NoValue] =NOTSET): """ :rtype: urllib.parse.ParseResult """ @@ -336,13 +327,13 @@ def search_url(self, var=DEFAULT_SEARCH_ENV, default=NOTSET, engine=None): engine=engine ) - def path(self, var, default=NOTSET, **kwargs): + def path(self, var, default: str = NOTSET, **kwargs): """ :rtype: Path """ - return Path(self.get_value(var, default=default), **kwargs) + return Path(str(self.get_value(var, default=default)), **kwargs) - def get_value(self, var, cast=None, default=NOTSET, parse_default=False): + def get_value(self, var, cast: Optional[str] = None, default: Any = NOTSET, parse_default=False): """Return value for given environment variable. :param str var: @@ -361,7 +352,7 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False): "get '%s' casted as '%s' with default '%s'", var, cast, default) - var_name = f'{self.prefix}{var}' + var_name = f"{self.prefix}{var}" if var_name in self.scheme: var_info = self.scheme[var_name] @@ -375,19 +366,16 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False): cast = var_info[0] if default is self.NOTSET: - try: + with contextlib.suppress(IndexError): default = var_info[1] - except IndexError: - pass - else: - if not cast: - cast = var_info + elif not cast: + cast = var_info try: - value = self.ENVIRON[var_name] + value: Any = self.ENVIRON[var_name] except KeyError as exc: if default is self.NOTSET: - error_msg = f'Set the {var} environment variable' + error_msg = f"Set the {var} environment variable" raise ImproperlyConfigured(error_msg) from exc value = default @@ -403,10 +391,13 @@ def get_value(self, var, cast=None, default=NOTSET, parse_default=False): value = value.replace(escape, prefix) # Smart casting - if self.smart_cast: - if cast is None and default is not None and \ - not isinstance(default, NoValue): - cast = type(default) + if ( + self.smart_cast + and cast is None + and default is not None + and not isinstance(default, NoValue) + ): + cast = type(default) value = None if default is None and value == '' else value @@ -457,7 +448,7 @@ def parse_value(cls, value, cast): elif cast is tuple: val = value.strip('(').strip(')').split(',') # pylint: disable=consider-using-generator - value = tuple([x for x in val if x]) + value = tuple(x for x in val if x) elif cast is float: # clean string float_str = re.sub(r'[^\d,.-]', '', value) @@ -467,7 +458,7 @@ def parse_value(cls, value, cast): if len(parts) == 1: float_str = parts[0] else: - float_str = f"{''.join(parts[0:-1])}.{parts[-1]}" + float_str = f"{''.join(parts[:-1])}.{parts[-1]}" value = float(float_str) else: value = cast(value) @@ -581,21 +572,17 @@ def db_url_config(cls, url, engine=None): config_options = {} for k, v in parse_qs(url.query).items(): if k.upper() in cls._DB_BASE_OPTIONS: - config.update({k.upper(): _cast(v[0])}) + config[k.upper()] = _cast(v[0]) else: - config_options.update({k: _cast_int(v[0])}) + config_options[k] = _cast_int(v[0]) config['OPTIONS'] = config_options - if engine: - config['ENGINE'] = engine - else: - config['ENGINE'] = url.scheme - + config['ENGINE'] = engine or url.scheme if config['ENGINE'] in cls.DB_SCHEMES: config['ENGINE'] = cls.DB_SCHEMES[config['ENGINE']] if not config.get('ENGINE', False): - warnings.warn(f'Engine not recognized from url: {config}') + warnings.warn(f"Engine not recognized from url: {config}") return {} return config @@ -630,9 +617,7 @@ def cache_url_config(cls, url, backend=None): # Add the drive to LOCATION if url.scheme == 'filecache': - config.update({ - 'LOCATION': url.netloc + url.path, - }) + config['LOCATION'] = url.netloc + url.path # urlparse('pymemcache://127.0.0.1:11211') # => netloc='127.0.0.1:11211', path='' @@ -643,21 +628,11 @@ def cache_url_config(cls, url, backend=None): # urlparse('memcache:///tmp/memcached.sock') # => netloc='', path='/tmp/memcached.sock' if not url.netloc and url.scheme in ['memcache', 'pymemcache']: - config.update({ - 'LOCATION': 'unix:' + url.path, - }) + config['LOCATION'] = f'unix:{url.path}' elif url.scheme.startswith('redis'): - if url.hostname: - scheme = url.scheme.replace('cache', '') - else: - scheme = 'unix' - locations = [scheme + '://' + loc + url.path - for loc in url.netloc.split(',')] - if len(locations) == 1: - config['LOCATION'] = locations[0] - else: - config['LOCATION'] = locations - + scheme = url.scheme.replace('cache', '') if url.hostname else 'unix' + locations = [f'{scheme}://{loc}{url.path}' for loc in url.netloc.split(',')] + config['LOCATION'] = locations[0] if len(locations) == 1 else locations if url.query: config_options = {} for k, v in parse_qs(url.query).items(): @@ -687,7 +662,7 @@ def email_url_config(cls, url, backend=None): config = {} - url = urlparse(url) if not isinstance(url, cls.URL_CLASS) else url + url = url if isinstance(url, cls.URL_CLASS) else urlparse(url) # Remove query strings path = url.path[1:] @@ -738,9 +713,7 @@ def search_url_config(cls, url, engine=None): :rtype: dict """ - config = {} - - url = urlparse(url) if not isinstance(url, cls.URL_CLASS) else url + url = url if isinstance(url, cls.URL_CLASS) else urlparse(url) # Remove query strings. path = url.path[1:] @@ -756,7 +729,7 @@ def search_url_config(cls, url, engine=None): params = parse_qs(url.query) if 'EXCLUDED_INDEXES' in params: config['EXCLUDED_INDEXES'] \ - = params['EXCLUDED_INDEXES'][0].split(',') + = params['EXCLUDED_INDEXES'][0].split(',') if 'INCLUDE_SPELLING' in params: config['INCLUDE_SPELLING'] = cls.parse_value( params['INCLUDE_SPELLING'][0], @@ -770,9 +743,11 @@ def search_url_config(cls, url, engine=None): if url.scheme == 'simple': return config - if url.scheme in ['solr'] + cls.ELASTICSEARCH_FAMILY: - if 'KWARGS' in params: - config['KWARGS'] = params['KWARGS'][0] + if ( + url.scheme in ['solr'] + cls.ELASTICSEARCH_FAMILY + and 'KWARGS' in params + ): + config['KWARGS'] = params['KWARGS'][0] # remove trailing slash if path.endswith('/'): @@ -804,7 +779,7 @@ def search_url_config(cls, url, engine=None): config['INDEX_NAME'] = index return config - config['PATH'] = '/' + path + config['PATH'] = f'/{path}' if url.scheme == 'whoosh': if 'STORAGE' in params: @@ -882,21 +857,18 @@ def read_env(cls, env_file=None, overwrite=False, encoding='utf8', def _keep_escaped_format_characters(match): """Keep escaped newline/tabs in quoted strings""" escaped_char = match.group(1) - if escaped_char in 'rnt': - return '\\' + escaped_char - return escaped_char + return '\\' + escaped_char if escaped_char in 'rnt' else escaped_char for line in content.splitlines(): m1 = re.match(r'\A(?:export )?([A-Za-z_0-9]+)=(.*)\Z', line) if m1: - key, val = m1.group(1), m1.group(2) + key, val = m1[1], m1[2] m2 = re.match(r"\A'(.*)'\Z", val) if m2: - val = m2.group(1) + val = m2[1] m3 = re.match(r'\A"(.*)"\Z', val) if m3: - val = re.sub(r'\\(.)', _keep_escaped_format_characters, - m3.group(1)) + val = re.sub(r'\\(.)', _keep_escaped_format_characters, m3[1]) overrides[key] = str(val) elif not line or line.startswith('#'): # ignore warnings for empty line-breaks or comments @@ -992,9 +964,11 @@ def __ne__(self, other): return not self.__eq__(other) def __add__(self, other): - if not isinstance(other, Path): - return Path(self.__root__, other) - return Path(self.__root__, other.__root__) + return ( + Path(self.__root__, other.__root__) + if isinstance(other, Path) + else Path(self.__root__, other) + ) def __sub__(self, other): if isinstance(other, int): @@ -1019,7 +993,7 @@ def __contains__(self, item): return item.__root__.startswith(base_path) def __repr__(self): - return f'' + return f"" def __str__(self): return self.__root__