Skip to content

Commit

Permalink
Merge pull request #628 from wrangleworks/yaml-CLoader
Browse files: browse the repository at this point in the history
Refactor YAML handling to use CSafeDumper and CSafeLoader for improved performance
  • Loading branch information
ChrisWRWX authored Feb 24, 2025
2 parents: 52c941b + 0afa7ad; commit e836b9b
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
14 changes: 13 additions & 1 deletion wrangles/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import requests as _requests
import numpy as _np
import time as _time
try:
from yaml import CSafeDumper as _YAMLDumper
except ImportError:
from yaml import SafeDumper as _YAMLDumper


def chatGPT(
data: any,
Expand All @@ -28,7 +33,14 @@ def chatGPT(
if len(data) == 1:
content = list(data.values())[0]
else:
content = _yaml.dump(data, indent=2, sort_keys=False, allow_unicode=True)
content = _yaml.dump(
data,
indent=2,
sort_keys=False,
allow_unicode=True,
Dumper=_YAMLDumper,
width=1000
)

settings_local = _copy.deepcopy(settings)
settings_local["messages"].append(
Expand Down
9 changes: 6 additions & 3 deletions wrangles/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
add_special_parameters as _add_special_parameters,
wildcard_expansion as _wildcard_expansion
)

try:
from yaml import CSafeLoader as _YamlLoader, CSafeDumper as _YAMLDumper
except ImportError:
from yaml import SafeLoader as _YamlLoader, SafeDumper as _YAMLDumper

_logging.getLogger().setLevel(_logging.INFO)

Expand Down Expand Up @@ -122,7 +125,7 @@ def _replace_templated_values(
and '\n' in replacement_value
):
try:
replacement_value = _yaml.safe_load(replacement_value)
replacement_value = _yaml.load(replacement_value, Loader=_YamlLoader)
except:
# Replacement wasn't YAML
pass
Expand Down Expand Up @@ -175,7 +178,7 @@ def _load_recipe(
if not isinstance(recipe, str):
try:
# If user passes in a pre-parsed recipe, convert back to YAML
recipe = _yaml.dump(recipe, sort_keys=False, allow_unicode=True)
recipe = _yaml.dump(recipe, sort_keys=False, Dumper=_YAMLDumper, allow_unicode=True)
except:
raise ValueError('Recipe passed in as an invalid type')

Expand Down
8 changes: 7 additions & 1 deletion wrangles/recipe_wrangles/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import pandas as _pd
from fractions import Fraction as _Fraction
import yaml as _yaml
try:
from yaml import CSafeLoader as _YAMLLoader, CSafeDumper as _YAMLDumper
except ImportError:
from yaml import SafeLoader as _YAMLLoader, SafeDumper as _YAMLDumper


def case(df: _pd.DataFrame, input: _Union[str, list], output: _Union[str, list] = None, case: str = 'lower') -> _pd.DataFrame:
Expand Down Expand Up @@ -374,7 +378,7 @@ def _load_with_fallback(value):
If no default, raise an error.
"""
try:
return _yaml.safe_load(value, **kwargs) or default
return _yaml.load(value, Loader=_YAMLLoader, **kwargs) or default
except:
if default != None:
return default
Expand Down Expand Up @@ -458,6 +462,7 @@ def to_yaml(
row,
sort_keys=sort_keys,
allow_unicode=allow_unicode,
Dumper=_YAMLDumper,
**kwargs
)
for row in df[input_columns].values
Expand All @@ -470,6 +475,7 @@ def to_yaml(
row,
sort_keys=sort_keys,
allow_unicode=allow_unicode,
Dumper=_YAMLDumper,
**kwargs
)
for row in df[input].to_dict(orient="records")
Expand Down

0 comments on commit e836b9b

Please sign in to comment.