Skip to content

Commit

Permalink
Adds support for google repos stored on GCS (#800)
Browse files Browse the repository at this point in the history
* added support for google repository

* added more tests

* increased codecov

* refactored some code
  • Loading branch information
AbhinavTuli authored Apr 29, 2021
1 parent e161832 commit fba7d53
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 14 deletions.
45 changes: 44 additions & 1 deletion hub/api/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from hub import load, transform
from hub.api.dataset_utils import slice_extract_info, slice_split, check_class_label
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException, SchemaMismatchException
from hub.exceptions import (
DirectoryNotEmptyException,
SchemaMismatchException,
ReadModeException,
)
from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text
from hub.utils import (
azure_creds_exist,
Expand Down Expand Up @@ -1291,6 +1295,45 @@ def my_filter(sample):
assert ds3["abc", i].compute() == 5 * i


def test_dataset_google():
ds = Dataset("google/bike")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/bottle")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/book")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/cereal_box")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/chair")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/cup")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/camera")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/laptop")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3
ds = Dataset("google/shoe")
assert ds["image_channels", 0].compute() == 3
with pytest.raises(ReadModeException):
ds["image_channels", 0] = 3


if __name__ == "__main__":
# test_dataset_assign_value()
# test_dataset_setting_shape()
Expand Down
1 change: 0 additions & 1 deletion hub/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def request(
or headers["Authorization"] != self.auth_header
):
headers["Authorization"] = self.auth_header

try:
logger.debug(f"Sending: Headers {headers}, Json: {json}")

Expand Down
15 changes: 14 additions & 1 deletion hub/client/hub_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def get_dataset_path(self, tag):
return dataset

def get_credentials(self):

if self.auth_header is None:
token = AuthClient().get_access_token(username="public", password="")
self.auth_header = f"Bearer {token}"
Expand All @@ -66,6 +65,20 @@ def get_credentials(self):
self.save_config(details)
return details

def get_dataset_credentials(self, org_id, ds_name):
self.auth_header = (
self.auth_header
if self.auth_header is not None
else f'Bearer {AuthClient().get_access_token(username="public", password="")}'
)
relative_url = config.GET_DATASET_CREDENTIALS_SUFFIX % (org_id, ds_name)
r = self.request(
"GET",
relative_url,
endpoint=config.HUB_REST_ENDPOINT,
).json()
return r["creds"], r["path"]

def get_config(self, reset=False):
if not os.path.isfile(config.STORE_CONFIG_PATH) or self.auth_header is None:
self.get_credentials()
Expand Down
1 change: 1 addition & 0 deletions hub/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GET_REGISTER_SUFFIX = "/api/user/register"
GET_DATASET_SUFFIX = "/api/dataset/get"
GET_DATASET_PATH_SUFFIX = "/api/dataset/get/path"
GET_DATASET_CREDENTIALS_SUFFIX = "/api/org/%s/ds/%s/creds"

CREATE_DATASET_SUFFIX = "/api/dataset/create"
UPDATE_STATE_SUFFIX = "/api/dataset/state"
Expand Down
28 changes: 17 additions & 11 deletions hub/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,23 @@ def get_fs_and_path(
return fsspec.filesystem("file"), url
else:
# TOOD check if url is username/dataset:version
url, creds = _connect(url, public=public)
fs = S3FileSystemReplacement(
expiration=creds["expiration"],
key=creds["access_key"],
secret=creds["secret_key"],
token=creds["session_token"],
client_kwargs={
"endpoint_url": creds["endpoint"],
"region_name": creds["region"],
},
)
if url.split("/")[0] == "google":
org_id, ds_name = url.split("/")
token, url = HubControlClient().get_dataset_credentials(org_id, ds_name)
fs = gcsfs.GCSFileSystem(token=token)
url = url[6:]
else:
url, creds = _connect(url, public=public)
fs = S3FileSystemReplacement(
expiration=creds["expiration"],
key=creds["access_key"],
secret=creds["secret_key"],
token=creds["session_token"],
client_kwargs={
"endpoint_url": creds["endpoint"],
"region_name": creds["region"],
},
)
return (fs, url)


Expand Down

0 comments on commit fba7d53

Please sign in to comment.