Hello all,
I’ve a JWT authentication mechanism that I would like to use to authenticate dask-gateway. Where should the code for that live? Can it go in the helm configuration? Should I fork dask-gateway?
Interested in this as well!
Maybe you could contribute it to dask-gateway?
Could be.
What I’m doing currently is similar to the BasicAuth mechanism, except that the user does not pass credentials (user, password) to the auth mechanism but a JWT instead.
On the dask-gateway-server side, I have code that forwards the headers to a validation service…
Does that sound like something that could go in dask-gateway?
Sure that sounds great
To authenticate a user with JWT, the user sends the token in a request header and we verify it against a JSON Web Key Set (JWKS). The user does not provide credentials to the gateway, so they must be able to obtain a token beforehand.
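To make the JWKS lookup more concrete, here is a minimal standard-library sketch of the step that happens before verification: reading the `kid` from the token's unverified header and picking the matching entry from the key set. It does not verify the signature (that still needs a library such as PyJWT), and the token, key IDs, and helper names below are made up for illustration.

```python
import base64
import json


def b64url_decode(segment: str) -> bytes:
    """Decode a base64url segment, restoring the padding that JWTs strip."""
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def b64url_encode(raw: bytes) -> str:
    """Encode bytes as unpadded base64url, as used in compact JWTs."""
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def unverified_header(token: str) -> dict:
    """Return the header of a compact JWT WITHOUT verifying anything."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("a compact JWT has exactly three dot-separated segments")
    return json.loads(b64url_decode(parts[0]))


def select_jwk(jwks: dict, token: str) -> dict:
    """Pick the key-set entry whose 'kid' matches the token header's 'kid'."""
    kid = unverified_header(token).get("kid")
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    raise LookupError(f"no key with kid={kid!r} in the key set")


# Made-up token and key set, for illustration only (the signature is a dummy).
header = b64url_encode(json.dumps({"alg": "RS256", "kid": "key-2"}).encode())
payload = b64url_encode(
    json.dumps({"additional_info": {"username": "alice"}}).encode()
)
token = f"{header}.{payload}.{b64url_encode(b'not-a-real-signature')}"
jwks = {"keys": [{"kid": "key-1", "kty": "RSA"}, {"kid": "key-2", "kty": "RSA"}]}

selected = select_jwk(jwks, token)
```

Only after the right key has been selected can the signature and claims be checked, so nothing read from the unverified header or payload should be trusted on its own.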
I use the dask-gateway helm chart and modify its values.
RELEASE=dask-gateway
NAMESPACE=dask-gateway
helm upgrade $RELEASE dask-gateway \
    --repo=https://helm.dask.org \
    --install \
    --namespace $NAMESPACE \
    --values dask-gateway-helm-config-2022.10.0.yaml
I built my own dask-gateway-server image with pyjwt added as a dependency. The image is built from the PyPI package of dask-gateway-server plus the dependencies listed in the source, and I configure the helm chart to use that image.
# The image to use for the dask-gateway-server pod (api pod)
image:
  name: docker.image.registry.of.some.sort/dask-gateway-server # ghcr.io/dask/dask-gateway-server
  tag: "mytag"
  pullPolicy: Always
We use PyJWKClient to download the JSON Web Key Set and verify the token, which gives us the claims. WARNING: HACKITY HACK AHEAD. I put the code directly in the helm chart values YAML! Why? Because otherwise I would have to create a library and add it to the Docker image. But it works.
# Any extra configuration code to append to the generated `dask_gateway_config.py`
# file. Can be either a single code-block, or a map of key -> code-block
# (code-blocks are run in alphabetical order by key, the key value itself is
# meaningless). The map version is useful as it supports merging multiple
# `values.yaml` files, but is unnecessary in other cases.
extraConfig: |-
  import functools
  import contextvars
  import asyncio
  import aiohttp
  import jwt
  from aiohttp import web
  from dask_gateway_server.auth import SimpleAuthenticator, unauthorized, User
  from dask_gateway_server.options import Options, Integer, Float

  class JwtAuthenticator(SimpleAuthenticator):
      jwks_url = "https://server/where/you/get/jwks"  # json

      async def setup(self, app):
          self.session = aiohttp.ClientSession()
          self._jwks_client = jwt.PyJWKClient(self.jwks_url)

      # Maybe premature optimization... get_signing_key_from_jwt is a blocking
      # call, so run it in a thread.
      async def _wrapped_get_signing_key_from_jwt(self, token):
          # no asyncio.to_thread in 3.8 :(
          loop = asyncio.get_event_loop()
          ctx = contextvars.copy_context()
          func = self._jwks_client.get_signing_key_from_jwt
          func_call = functools.partial(ctx.run, func, token)
          signing_key = await loop.run_in_executor(None, func_call)
          return signing_key

      def validate(self, signing_key, token):
          return jwt.decode(token, signing_key.key, algorithms=["RS256"])

      async def cleanup(self):
          if hasattr(self, "session"):
              await self.session.close()

      async def authenticate(self, request):
          auth_headers = request.headers.get("Authorization")
          if not auth_headers:
              raise web.HTTPUnauthorized(reason="No JWT in Headers.")
          # not happy with this, but :shrug:
          try:
              idx = auth_headers.index("Bearer ")
          except ValueError:
              raise web.HTTPUnauthorized(reason="No 'Bearer' in Authorization header.")
          token = auth_headers[idx + 7 :]  # len("Bearer ") == 7
          signing_key = await self._wrapped_get_signing_key_from_jwt(token)
          try:
              data = self.validate(signing_key, token)
          except jwt.exceptions.ExpiredSignatureError:
              self.log.debug("JWT: Expired Key")
              raise unauthorized("expired jwt")
          except jwt.exceptions.InvalidTokenError:
              # Bad signature, wrong algorithm, malformed token, etc.:
              # reject instead of letting the error surface as a 500.
              self.log.debug("JWT: Invalid token")
              raise unauthorized("jwt")
          if data:
              return User(
                  data["additional_info"]["username"],
                  groups=[],
                  admin=False,
              )
          else:
              self.log.debug("JWT: No data in token validation.")
              raise unauthorized("jwt")

  def options_handler(options):
      return {
          "worker_cores": options.worker_cores,
          # GiB -> bytes
          "worker_memory": int(options.worker_memory * 2 ** 30),
      }

  c.DaskGateway.authenticator_class = JwtAuthenticator
  c.Backend.cluster_options = Options(
      Integer("worker_cores", default=1, min=1, max=4, label="Worker Cores"),
      Float("worker_memory", default=1, min=1, max=8, label="Worker Memory (GiB)"),
      handler=options_handler,
  )
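On the "not happy with this" header parsing: one possible alternative is to split the header into scheme and credentials instead of searching for the substring "Bearer ", which would also match if that text appeared anywhere else in the header. A minimal sketch follows; `extract_bearer_token` is a hypothetical helper name, not part of dask-gateway or aiohttp.

```python
from typing import Optional


def extract_bearer_token(header_value: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' header value.

    Returns None when the header is missing, uses another scheme,
    or carries no token.
    """
    if not header_value:
        return None
    # Split once into "<scheme> <credentials>".
    scheme, _, credentials = header_value.strip().partition(" ")
    # Compare the scheme case-insensitively.
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None


# Illustrative values only.
good = extract_bearer_token("Bearer abc.def.ghi")
wrong_scheme = extract_bearer_token("Basic dXNlcjpwYXNz")
empty = extract_bearer_token("Bearer ")
```

Inside `authenticate`, a `None` result would map to the same `web.HTTPUnauthorized` response the original code raises.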
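On the client side, a small GatewayAuth subclass sends the token as a Bearer header:
from dask_gateway import Gateway
from dask_gateway.auth import GatewayAuth

class JwtAuth(GatewayAuth):
    def __init__(self, token):
        self.token = token

    def pre_request(self, resp):
        # Returns (headers, context); no context is needed here.
        headers = {"Authorization": "Bearer " + self.token}
        return headers, None

# GATEWAY_ENDPOINT and token are supplied by the user.
g = Gateway(GATEWAY_ENDPOINT, auth=JwtAuth(token=token))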
For this solution to be distributable, jwks_url would need to be configurable.
@jacobtomlinson does this make sense to add to the dask-gateway code?
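As one way to avoid hard-coding `jwks_url` in the meantime, the config code could read it from an environment variable set on the gateway pod. This is only a sketch: the variable name `JWKS_URL` and the helper `jwks_url_from_env` are hypothetical, and how the variable reaches the pod depends on the deployment. Since dask-gateway's configuration is built on traitlets, an upstream contribution would more likely expose it as a configurable trait instead.

```python
import os
from typing import Mapping

DEFAULT_ENV_VAR = "JWKS_URL"  # hypothetical variable name


def jwks_url_from_env(
    environ: Mapping[str, str] = os.environ, var: str = DEFAULT_ENV_VAR
) -> str:
    """Read the JWKS endpoint from the environment, failing loudly if unset."""
    url = environ.get(var, "").strip()
    if not url:
        raise RuntimeError(f"{var} is not set; cannot locate the JSON Web Key Set")
    if not url.startswith("https://"):
        # Signing keys should not be fetched over an unencrypted connection.
        raise RuntimeError(f"{var} must be an https:// URL, got {url!r}")
    return url


# Illustrative call with an explicit mapping instead of the real environment.
example_url = jwks_url_from_env({"JWKS_URL": "https://server/where/you/get/jwks"})
```

In the authenticator above, the class attribute could then become `jwks_url = jwks_url_from_env()`.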