How to add a custom authenticator in dask-gateway?

Hello all,

I have a JWT authentication mechanism that I would like to use to authenticate against dask-gateway. Where should the code for that live? Can it go in the helm configuration? Should I fork dask-gateway?

Interested in this as well!

Maybe you could contribute it to dask-gateway?

Could be.

What I’m doing currently is similar to the BasicAuth mechanism, except that the user does not pass credentials (user, password) to the auth mechanism but rather a JWT.

On the dask-gateway-server side, I have code that forwards the headers to a validation service…

Does that sound like something that could go in dask-gateway?

Sure, that sounds great.

Solution

To authenticate a user with JWT, the user sends the token in the Authorization header, and the server verifies it against a JSON Web Key Set (JWKS). The user never provides credentials to dask-gateway, so they must be able to obtain a token some other way.
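To make the key-set part of this concrete, here is a minimal sketch of how a verifier can pick the right key out of a JWKS: it reads the `kid` from the token's (still unverified) header and looks up the matching entry. The sample JWKS, the `kid` values, and all helper names are made up for illustration. The signature check itself is not shown; in the solution below that part is left to PyJWT.

```python
import base64
import json


def b64url_decode(segment: str) -> bytes:
    """Decode a base64url segment, restoring the padding that JWTs strip."""
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def b64url_encode(raw: bytes) -> str:
    """Encode bytes as unpadded base64url, as used in JWT segments."""
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def unverified_header(token: str) -> dict:
    """Return the header of a JWT without verifying anything."""
    header_segment = token.split(".")[0]
    return json.loads(b64url_decode(header_segment))


def select_jwk(token: str, jwks: dict) -> dict:
    """Find the key in the JWKS whose `kid` matches the token header."""
    kid = unverified_header(token).get("kid")
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    raise LookupError(f"no key with kid={kid!r} in JWKS")


# Hypothetical data: a fake token with a dummy signature and a two-key JWKS.
header = b64url_encode(json.dumps({"alg": "RS256", "kid": "key-2"}).encode())
payload = b64url_encode(json.dumps({"sub": "alice"}).encode())
example_token = f"{header}.{payload}.signature-not-checked-here"
example_jwks = {
    "keys": [
        {"kid": "key-1", "kty": "RSA"},
        {"kid": "key-2", "kty": "RSA"},
    ]
}

selected = select_jwk(example_token, example_jwks)
```

Nothing read from the header can be trusted until the signature has been verified with the selected key, so this lookup is only the first half of the process.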

Deployment

I use the dask-gateway helm chart and modify its “values”.

RELEASE=dask-gateway
NAMESPACE=dask-gateway

helm upgrade $RELEASE dask-gateway \
    --repo=https://helm.dask.org \
    --install \
    --namespace $NAMESPACE \
    --values dask-gateway-helm-config-2022.10.0.yaml

dask-gateway-image with pyjwt

I created my own image of dask-gateway-server and added pyjwt as a dependency.

I built the image from the PyPI distribution of dask-gateway-server plus the dependencies listed in the source, and then configured the helm chart to use that image.

  # The image to use for the dask-gateway-server pod (api pod)
  image:
    name: docker.image.registry.of.some.sort/dask-gateway-server   # ghcr.io/dask/dask-gateway-server
    tag: "mytag"
    pullPolicy: Always 

Code

We use PyJWKClient to download the JSON Web Key Set, then verify the token and read its claims. WARNING: HACKITY HACK AHEAD. I placed the code in the helm chart values YAML! Why? Because otherwise I would have to create a library and add it to the docker image. But it works.

  # Any extra configuration code to append to the generated `dask_gateway_config.py`
  # file. Can be either a single code-block, or a map of key -> code-block
  # (code-blocks are run in alphabetical order by key, the key value itself is
  # meaningless). The map version is useful as it supports merging multiple
  # `values.yaml` files, but is unnecessary in other cases.
  extraConfig: |-
    import functools
    import contextvars
    
    import asyncio
    import aiohttp
    import jwt
    
    from aiohttp import web
    
    from dask_gateway_server.auth import SimpleAuthenticator, unauthorized, User
    from dask_gateway_server.options import Options, Integer, Float
    
    class JwtAuthenticator(SimpleAuthenticator):
    
        jwks_url = "https://server/where/you/get/jwks"  # json
    
        async def setup(self, app):
            self.session = aiohttp.ClientSession()
            self._jwks_client = jwt.PyJWKClient(self.jwks_url)
    
        # Maybe premature optimization... get_signing_key_from_jwt is blocking call,
        # so wrap it around a thread, maybe.
        async def _wrapped_get_signing_key_from_jwt(self, token):
            # no asyncio.to_thread in 3.8 :(
            loop = asyncio.get_running_loop()
            ctx = contextvars.copy_context()
            func = self._jwks_client.get_signing_key_from_jwt
            func_call = functools.partial(ctx.run, func, token)
            signing_key = await loop.run_in_executor(None, func_call)
            return signing_key
    
        def validate(self, signing_key, token):
            return jwt.decode(token, signing_key.key, algorithms=["RS256"])
    
        async def cleanup(self):
            if hasattr(self, "session"):
                await self.session.close()
    
        async def authenticate(self, request):
            auth_header = request.headers.get("Authorization")
            if not auth_header:
                raise web.HTTPUnauthorized(reason="No JWT in Headers.")

            # Only accept the "Bearer <token>" scheme at the start of the header.
            prefix = "Bearer "
            if not auth_header.startswith(prefix):
                raise web.HTTPUnauthorized(reason="No 'Bearer' in Authorization header.")
            token = auth_header[len(prefix):].strip()

            try:
                # Fetching the signing key can also fail (malformed token, unknown kid).
                signing_key = await self._wrapped_get_signing_key_from_jwt(token)
                data = self.validate(signing_key, token)
            except jwt.exceptions.ExpiredSignatureError:
                self.log.debug("JWT: Expired Key")
                raise unauthorized("expired jwt")
            except jwt.exceptions.PyJWTError:
                # Any other PyJWT failure is a 401 rather than a server error.
                self.log.debug("JWT: Token validation failed.")
                raise unauthorized("jwt")

            if data:
                return User(
                    data["additional_info"]["username"],
                    groups=[],
                    admin=False,
                )
            else:
                self.log.debug("JWT: No data in token validation.")
                raise unauthorized("jwt")

    def options_handler(options):
        return {
            "worker_cores": options.worker_cores,
            "worker_memory": int(options.worker_memory * 2 ** 30),
        }
    c.DaskGateway.authenticator_class = JwtAuthenticator
    c.Backend.cluster_options = Options(
        Integer("worker_cores", default=1, min=1, max=4, label="Worker Cores"),
        Float("worker_memory", default=1, min=1, max=8, label="Worker Memory (GiB)"),
        handler=options_handler,
    )
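The thread-offloading pattern used in `_wrapped_get_signing_key_from_jwt` can be tried in isolation. The sketch below uses only the standard library; `blocking_lookup` is a made-up stand-in for the blocking JWKS call, not part of PyJWT or dask-gateway.

```python
import asyncio
import functools
import time


def blocking_lookup(token: str) -> str:
    """Stand-in for a blocking call such as a JWKS download."""
    time.sleep(0.05)
    return f"key-for-{token}"


async def lookup_in_thread(token: str) -> str:
    """Run the blocking call on the loop's default thread pool executor."""
    loop = asyncio.get_running_loop()
    call = functools.partial(blocking_lookup, token)
    return await loop.run_in_executor(None, call)


async def main() -> list:
    # The two lookups run in worker threads, so the event loop stays free
    # to serve other requests while they sleep.
    return await asyncio.gather(lookup_in_thread("a"), lookup_in_thread("b"))


results = asyncio.run(main())
```

On Python 3.9 and later, `asyncio.to_thread(blocking_lookup, token)` does the same thing in one call, including copying the current context.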

Client side

from dask_gateway import Gateway
from dask_gateway.auth import GatewayAuth

class JwtAuth(GatewayAuth):
    def __init__(self, token):
        self.token = token
    def pre_request(self, resp):
        data = self.token
        headers = {"Authorization": "Bearer " + data}
        return headers, None

g = Gateway(GATEWAY_ENDPOINT, auth=JwtAuth(token=token))

Further work

For this solution to be distributable:

  1. Should dask-gateway-server depend on pyjwt? Can it be more pluggable?
  2. Make the jwks_url configurable.
  3. The claims in a JWT vary between issuers, so the mapping from claims to a user should be configurable.
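For item 3, one possible approach (a sketch, not something dask-gateway provides) is to describe the username claim as a dotted path and resolve it against the decoded claims. The helper name `claim_at` and the sample claims are hypothetical; the path could presumably be exposed as a config option next to `jwks_url`.

```python
from typing import Any, Mapping


def claim_at(claims: Mapping[str, Any], path: str) -> Any:
    """Follow a dotted path such as "additional_info.username" into the claims."""
    current: Any = claims
    for part in path.split("."):
        if not isinstance(current, Mapping) or part not in current:
            raise KeyError(f"claim {path!r} not found (missing {part!r})")
        current = current[part]
    return current


# Hypothetical claims, shaped like the ones the authenticator above reads.
claims = {"sub": "1234", "additional_info": {"username": "alice"}}
username = claim_at(claims, "additional_info.username")
subject = claim_at(claims, "sub")
```

With this in place, the hard-coded `data["additional_info"]["username"]` lookup in the authenticator could become `claim_at(data, configured_path)`, and a missing claim would turn into a clear error instead of a bare `KeyError` on one level.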

@jacobtomlinson does it make sense to add this to the dask-gateway code?