Skip to content
Merged
98 changes: 42 additions & 56 deletions mergin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
proxy_config=None,
):
self.url = url if url is not None else MerginClient.default_url()
self._auth_params = None
self._auth_params = {}
self._auth_session = None
self._user_info = None
self._server_type = None
Expand Down Expand Up @@ -192,36 +192,32 @@ def user_agent_info(self):
system_version = platform.mac_ver()[0]
return f"{self.client_version} ({platform.system()}/{system_version})"

def _check_token(f):
"""Wrapper for creating/renewing authorization token."""

def wrapper(self, *args):
if self._auth_params:
if self._auth_session:
# Refresh auth token if it expired or will expire very soon
delta = self._auth_session["expire"] - datetime.now(timezone.utc)
if delta.total_seconds() < 5:
self.log.info("Token has expired - refreshing...")
if self._auth_params.get("login", None) and self._auth_params.get("password", None):
self.log.info("Token has expired - refreshing...")
self.login(self._auth_params["login"], self._auth_params["password"])
else:
raise AuthTokenExpiredError("Token has expired - please re-login")
else:
# Create a new authorization token
self.log.info(f"No token - login user: {self._auth_params['login']}")
if self._auth_params.get("login", None) and self._auth_params.get("password", None):
self.login(self._auth_params["login"], self._auth_params["password"])
else:
raise ClientError("Missing login or password")

return f(self, *args)
def validate_auth(self):
"""Validate that client has valid auth token or can be logged in."""

return wrapper
if self._auth_session:
# Refresh auth token if it expired or will expire very soon
delta = self._auth_session["expire"] - datetime.now(timezone.utc)
if delta.total_seconds() < 5:
self.log.info("Token has expired - refreshing...")
if self._auth_params.get("login", None) and self._auth_params.get("password", None):
self.log.info("Token has expired - refreshing...")
self.login(self._auth_params["login"], self._auth_params["password"])
else:
raise AuthTokenExpiredError("Token has expired - please re-login")
else:
# Create a new authorization token
self.log.info(f"No token - login user: {self._auth_params.get('login', None)}")
if self._auth_params.get("login", None) and self._auth_params.get("password", None):
self.login(self._auth_params["login"], self._auth_params["password"])
else:
raise ClientError("Missing login or password")

@_check_token
def _do_request(self, request):
def _do_request(self, request, validate_auth=True):
"""General server request method."""
if validate_auth:
self.validate_auth()

if self._auth_session:
request.add_header("Authorization", self._auth_session["token"])
request.add_header("User-Agent", self.user_agent_info())
Expand Down Expand Up @@ -263,31 +259,31 @@ def _do_request(self, request):
# e.g. when DNS resolution fails (no internet connection?)
raise ClientError("Error requesting " + request.full_url + ": " + str(e))

def get(self, path, data=None, headers={}):
def get(self, path, data=None, headers={}, validate_auth=True):
url = urllib.parse.urljoin(self.url, urllib.parse.quote(path))
if data:
url += "?" + urllib.parse.urlencode(data)
request = urllib.request.Request(url, headers=headers)
return self._do_request(request)
return self._do_request(request, validate_auth=validate_auth)

def post(self, path, data=None, headers={}):
def post(self, path, data=None, headers={}, validate_auth=True):
url = urllib.parse.urljoin(self.url, urllib.parse.quote(path))
if headers.get("Content-Type", None) == "application/json":
data = json.dumps(data, cls=DateTimeEncoder).encode("utf-8")
request = urllib.request.Request(url, data, headers, method="POST")
return self._do_request(request)
return self._do_request(request, validate_auth=validate_auth)

def patch(self, path, data=None, headers={}):
def patch(self, path, data=None, headers={}, validate_auth=True):
url = urllib.parse.urljoin(self.url, urllib.parse.quote(path))
if headers.get("Content-Type", None) == "application/json":
data = json.dumps(data, cls=DateTimeEncoder).encode("utf-8")
request = urllib.request.Request(url, data, headers, method="PATCH")
return self._do_request(request)
return self._do_request(request, validate_auth=validate_auth)

def delete(self, path):
def delete(self, path, validate_auth=True):
url = urllib.parse.urljoin(self.url, urllib.parse.quote(path))
request = urllib.request.Request(url, method="DELETE")
return self._do_request(request)
return self._do_request(request, validate_auth=validate_auth)

def login(self, login, password):
"""
Expand All @@ -303,26 +299,16 @@ def login(self, login, password):
self._auth_session = None
self.log.info(f"Going to log in user {login}")
try:
self._auth_params = params
url = urllib.parse.urljoin(self.url, urllib.parse.quote("/v1/auth/login"))
data = json.dumps(self._auth_params, cls=DateTimeEncoder).encode("utf-8")
request = urllib.request.Request(url, data, {"Content-Type": "application/json"}, method="POST")
request.add_header("User-Agent", self.user_agent_info())
resp = self.opener.open(request)
resp = self.post(
"/v1/auth/login", data=params, headers={"Content-Type": "application/json"}, validate_auth=False
)
data = json.load(resp)
session = data["session"]
except urllib.error.HTTPError as e:
if e.headers.get("Content-Type", "") == "application/problem+json":
info = json.load(e)
self.log.info(f"Login problem: {info.get('detail')}")
raise LoginError(info.get("detail"))
self.log.info(f"Login problem: {e.read().decode('utf-8')}")
raise LoginError(e.read().decode("utf-8"))
except urllib.error.URLError as e:
# e.g. when DNS resolution fails (no internet connection?)
raise ClientError("failure reason: " + str(e.reason))
except ClientError as e:
self.log.info(f"Login problem: {e.detail}")
raise LoginError(e.detail)
self._auth_session = {
"token": "Bearer %s" % session["token"],
"token": f"Bearer {session['token']}",
"expire": dateutil.parser.parse(session["expire"]),
}
self._user_info = {"username": data["username"]}
Expand Down Expand Up @@ -367,7 +353,7 @@ def server_type(self):
"""
if not self._server_type:
try:
resp = self.get("/config")
resp = self.get("/config", validate_auth=False)
config = json.load(resp)
if config["server_type"] == "ce":
self._server_type = ServerType.CE
Expand All @@ -389,7 +375,7 @@ def server_version(self):
"""
if self._server_version is None:
try:
resp = self.get("/config")
resp = self.get("/config", validate_auth=False)
config = json.load(resp)
self._server_version = config["version"]
except (ClientError, KeyError):
Expand Down Expand Up @@ -1386,7 +1372,7 @@ def remove_project_collaborator(self, project_id: str, user_id: int):

def server_config(self) -> dict:
"""Get server configuration as dictionary."""
response = self.get("/config")
response = self.get("/config", validate_auth=False)
return json.load(response)

def send_logs(
Expand Down
64 changes: 61 additions & 3 deletions mergin/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
import subprocess
import shutil
from datetime import datetime, timedelta, date
from datetime import datetime, timedelta, date, timezone
import pytest
import pytz
import sqlite3
Expand All @@ -14,6 +14,7 @@
from .. import InvalidProject
from ..client import (
MerginClient,
AuthTokenExpiredError,
ClientError,
MerginProject,
LoginError,
Expand Down Expand Up @@ -2888,8 +2889,7 @@ def test_mc_without_login():
with pytest.raises(ClientError) as e:
mc.workspaces_list()

assert e.value.http_error == 401
assert e.value.detail == '"Authentication information is missing or invalid."\n'
assert e.value.detail == "Missing login or password"


def test_do_request_error_handling(mc: MerginClient):
Expand All @@ -2911,3 +2911,61 @@ def test_do_request_error_handling(mc: MerginClient):

assert e.value.http_error == 400
assert "Passwords must be at least 8 characters long." in e.value.detail


def test_validate_auth(mc: MerginClient):
"""Test validate authentication under different scenarios."""

# ----- Client without authentication -----
mc_not_auth = MerginClient(SERVER_URL)

with pytest.raises(ClientError) as e:
mc_not_auth.validate_auth()

assert e.value.detail == "Missing login or password"

# ----- Client with token -----
# create a client with valid auth token based on other MerginClient instance, but not with username/password
mc_auth_token = MerginClient(SERVER_URL, auth_token=mc._auth_session["token"])

# this should pass and not raise an error
mc_auth_token.validate_auth()

# manually set expire date to the past to simulate expired token
mc_auth_token._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1)

# check that this raises an error
with pytest.raises(AuthTokenExpiredError):
mc_auth_token.validate_auth()

# ----- Client with token and username/password -----
# create a client with valid auth token based on other MerginClient instance with username/password that allows relogin if the token is expired
mc_auth_token_login = MerginClient(
SERVER_URL, auth_token=mc._auth_session["token"], login=API_USER, password=USER_PWD
)

# this should pass and not raise an error
mc_auth_token_login.validate_auth()

# manually set expire date to the past to simulate expired token
mc_auth_token_login._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1)

# this should pass and not raise an error, as the client is able to re-login
mc_auth_token_login.validate_auth()

# ----- Client with token and username/WRONG password -----
# create a client with valid auth token based on other MerginClient instance with username and WRONG password
# that does NOT allow relogin if the token is expired
mc_auth_token_login_wrong_password = MerginClient(
SERVER_URL, auth_token=mc._auth_session["token"], login=API_USER, password="WRONG_PASSWORD"
)

# this should pass and not raise an error
mc_auth_token_login_wrong_password.validate_auth()

# manually set expire date to the past to simulate expired token
mc_auth_token_login_wrong_password._auth_session["expire"] = datetime.now(timezone.utc) - timedelta(days=1)

# this should pass and not raise an error, as the client is able to re-login
with pytest.raises(LoginError):
mc_auth_token_login_wrong_password.validate_auth()