Caching JWKS using Redis with Django
• 5 minute read
django, drf, auth0, jwt, redis
At the end of the last article about authentication with DRF, I described an issue we ran into: our application calls the JWKS endpoint every time it validates a JWT. The project I shared has this problem. Django has a lovely cache abstraction, and version 4.0 added a built-in Redis backend. So, to solve our problem, let's start by configuring the development environment.
Redis as a Compose service
This is how we'll configure the compose file:
```yaml
version: "3.9"
services:
  redis:
    image: redis:6.2.6-alpine
    command:
      [
        "redis-server",
        "--requirepass",
        "this-is-your-admin-password"
      ]
    ports:
      - "6379:6379"
```
Once the service is ready to accept connections, you can use `redis-cli` to test it:
```shell
docker-compose run redis redis-cli -h redis -p 6379 -a "this-is-your-admin-password"
```
If you type `PING`, you should receive `PONG`:
```
redis:6379> PING
PONG
```
Learn more in the Redis documentation.
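The compose file only starts Redis; Django still needs to know where to find it. A minimal sketch of the `CACHES` setting using the built-in Redis backend from Django 4.0 could look like the block below. The `REDIS_*` environment variable names are assumptions, the backend requires the `redis` (redis-py) package, and `localhost` assumes Django runs on your machine against the published port; use `redis` as the host if Django runs inside the same Compose network.

```python
# settings.py (sketch): point Django's built-in Redis cache backend at the
# Compose service. The REDIS_* environment variable names are assumptions.
import os

REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD", "this-is-your-admin-password")

CACHES = {
    "default": {
        # Backend shipped with Django 4.0+; needs the redis-py package installed.
        "BACKEND": "django.core.cache.backends.redis.RedisCache",
        # URL format: redis://[username]:password@host:port/db
        # (empty username, because the server only sets --requirepass).
        "LOCATION": f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}/0",
    }
}
```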
Discovering where PyJWT calls the JWKS endpoint
This discovery step is crucial: once we know which method requests the key set, we can override it and put the cache in between.
Looking at the constructor of `PyJWKClient`, we can see that it does not call the JWKS endpoint. Instead, while our `authenticate` method runs, the client retrieves the signing key through `get_signing_key_from_jwt`. This method eventually calls `fetch_data`, which is what actually requests the JWKS endpoint. Look at its implementation:
```python
def fetch_data(self) -> Any:
    with urllib.request.urlopen(self.uri) as response:
        return json.load(response)
```
Adding a caching layer
We can create a class that extends `PyJWKClient` and overrides `fetch_data`. Then, with Django's low-level cache API, we can use `get_or_set` so that the original `fetch_data` is called only when the value isn't already in the cache. Translating this idea into code:
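```python
from django.core.cache import cache
from jwt import PyJWKClient


class CachingJWKClient(PyJWKClient):
    cache_key = "MY_APP_NAME_JWKS"
    cache_timeout_1_day = 60 * 60 * 24

    def fetch_data(self):
        # get_or_set calls the callable (the original fetch_data, which
        # requests the JWKS endpoint) only on a cache miss, then stores
        # the result for one day.
        return cache.get_or_set(
            self.cache_key, super().fetch_data, timeout=self.cache_timeout_1_day
        )
```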
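If the `get_or_set` semantics feel abstract, here is a stdlib-only sketch of the same cache-aside pattern with a time-to-live. `InMemoryTTLCache` is a hypothetical stand-in written for illustration, not Django's implementation, and unlike Django's version it only accepts a callable default. It shows that the callable (in our case `super().fetch_data`) runs on a miss and is skipped on every hit until the entry expires.

```python
import time
from typing import Any, Callable, Dict, Tuple


class InMemoryTTLCache:
    """Hypothetical stand-in for Django's cache, for illustration only."""

    def __init__(self) -> None:
        # key -> (expires_at, value), using a monotonic clock
        self._store: Dict[str, Tuple[float, Any]] = {}

    def get_or_set(self, key: str, default: Callable[[], Any], timeout: float) -> Any:
        now = time.monotonic()
        entry = self._store.get(key)
        if entry is not None and entry[0] > now:
            # Hit: the callable is never invoked.
            return entry[1]
        # Miss or expired: compute the value once and store it with a TTL.
        value = default()
        self._store[key] = (now + timeout, value)
        return value


calls = []


def fake_fetch_data():
    # Hypothetical replacement for the network call; records each invocation.
    calls.append(1)
    return {"keys": []}


cache = InMemoryTTLCache()
for _ in range(3):
    cache.get_or_set("MY_APP_NAME_JWKS", fake_fetch_data, timeout=60 * 60 * 24)
print(len(calls))  # 1
```

A per-process dict like this one is not shared between workers; using Redis as the backend is what lets every process reuse the same cached key set.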
Testing the custom PyJWKClient class
Testing what we coded is relatively straightforward. We must guarantee that no matter how many times we call `fetch_data` on our custom class, the superclass's `fetch_data` runs only once:
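```python
# Runs with pytest, pytest-django and pytest-mock (for the `mocker` fixture).
from django.core.cache import cache

# Module path assumed from the patch target used below.
from cache_django.apps.core.api.authentication.authentications import CachingJWKClient


class TestJWKSCache:
    def test_should_use_cache_when_executing_fetch_data(self, mocker):
        # Arrange
        cache.delete("MY_APP_NAME_JWKS")
        url = "https://salted-url/.well-known/jwks.json"
        jwks_client = CachingJWKClient(url)
        fake_data = "salt-licker"
        # Patch the superclass method so no real HTTP request is made
        mock_fetch_data = mocker.patch(
            "cache_django.apps.core.api.authentication.authentications.PyJWKClient.fetch_data"
        )
        mock_fetch_data.return_value = fake_data
        # Act
        assert jwks_client.fetch_data() == fake_data
        assert jwks_client.fetch_data() == fake_data
        assert jwks_client.fetch_data() == fake_data
        assert jwks_client.fetch_data() == fake_data
        # Assert
        assert mock_fetch_data.call_count == 1
```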
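This works because we patch `fetch_data` on the superclass itself: our override reaches it through `super().fetch_data`, so the mock counts exactly how many times the JWKS endpoint would really have been requested 🤩.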
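Why does patching the superclass work? `super().fetch_data` is looked up on the parent class each time it runs, so replacing the parent's attribute with a mock intercepts only the network call while our override and its caching logic stay real. Below is a stdlib-only sketch of the same idea, using `unittest.mock.patch.object` instead of pytest-mock; `BaseClient` and `CachingClient` are hypothetical stand-ins for `PyJWKClient` and `CachingJWKClient`, with a plain dict playing the role of the cache.

```python
from unittest import mock


class BaseClient:
    """Hypothetical stand-in for PyJWKClient; the real one hits the network."""

    def fetch_data(self):
        raise RuntimeError("network access is not allowed in this sketch")


class CachingClient(BaseClient):
    """Hypothetical stand-in for CachingJWKClient, caching in a class-level dict."""

    _cache = {}

    def fetch_data(self):
        if "jwks" not in self._cache:
            # super().fetch_data is resolved on BaseClient at call time,
            # so a patch applied to BaseClient.fetch_data is what runs here.
            self._cache["jwks"] = super().fetch_data()
        return self._cache["jwks"]


def test_superclass_fetch_data_called_once():
    # Start from an empty cache, like cache.delete(...) in the real test.
    CachingClient._cache.clear()
    with mock.patch.object(BaseClient, "fetch_data", return_value="salt-licker") as mock_fetch:
        client = CachingClient()
        for _ in range(4):
            assert client.fetch_data() == "salt-licker"
    assert mock_fetch.call_count == 1
```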
Next steps and conclusion
Django's cache framework has many options that we didn't cover here, and I strongly recommend reading its guide. From personal experience, I can tell you that caching is difficult and risky, so never underestimate its potential to create chaos in your system. Use it wisely, and write tests to make sure everything works as expected.
We're getting close to the point where we can talk about APIView and OpenAPI schemas. The project already has a sample of the former, though it's nowhere near production-ready. See you soon 🤟!
See everything we did here on GitHub.
Posted listening to Ocarina of Time Ambiance - Grottos - 10 Hours 🎶.