Skip to content

Commit

Permalink
Improve JWKS support
Browse files Browse the repository at this point in the history
Remove assumption that the remote JWKS contains a single JWK.
  • Loading branch information
NSeydoux committed Mar 1, 2024
1 parent 1c1e96c commit 2c3df07
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 83 deletions.
115 changes: 65 additions & 50 deletions packages/core/src/util/token.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,26 @@
//

import { jest, it, describe, expect } from "@jest/globals";
import type { JWTPayload, KeyLike } from "jose";
import type * as Jose from "jose";
import { SignJWT, generateKeyPair, exportJWK } from "jose";
import { getWebidFromTokenPayload } from "./token";

jest.mock("jose", () => {
const actualJose = jest.requireActual("jose") as typeof Jose;
return {
...actualJose,
createRemoteJWKSet: jest.fn(),
};
});

describe("getWebidFromTokenPayload", () => {
// Singleton keys generated on the first call to mockJwk
let publicKey: KeyLike | undefined;
let privateKey: KeyLike | undefined;
let publicKey: Jose.KeyLike | undefined;
let privateKey: Jose.KeyLike | undefined;

const mockJwk = async (): Promise<{
publicKey: KeyLike;
privateKey: KeyLike;
publicKey: Jose.KeyLike;
privateKey: Jose.KeyLike;
}> => {
if (typeof publicKey === "undefined" || typeof privateKey === "undefined") {
const generatedPair = await generateKeyPair("ES256");
Expand All @@ -53,10 +61,10 @@ describe("getWebidFromTokenPayload", () => {
};

const mockJwt = async (
claims: JWTPayload,
claims: Jose.JWTPayload,
issuer: string,
audience: string,
signingKey?: KeyLike,
signingKey?: Jose.KeyLike,
): Promise<string> => {
return new SignJWT(claims)
.setProtectedHeader({ alg: "ES256" })
Expand All @@ -67,40 +75,13 @@ describe("getWebidFromTokenPayload", () => {
.sign(signingKey ?? (await mockJwk()).privateKey);
};

const mockFetch = (
payload: string,
statusCode: number,
statusText?: string,
): void => {
jest
.spyOn(globalThis, "fetch")
.mockResolvedValueOnce(
new Response(payload, { status: statusCode, statusText }),
);
};

it("throws if the JWKS cannot be fetched", async () => {
mockFetch("", 404, "Not Found");
const jwt = await mockJwt(
{ someClaim: true },
"https://some.issuer",
"https://some.clientId",
);
await expect(
getWebidFromTokenPayload(
jwt,
"https://some.jwks",
"https://some.issuer",
"https://some.clientId",
),
).rejects.toThrow(
"Could not fetch JWKS for [https://some.issuer] at [https://some.jwks]: 404 Not Found",
it("throws if the JWKS retrieval fails", async () => {
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockRejectedValue("Maformed JWKS"),
);
});

it("throws if the JWKS is malformed", async () => {
// Invalid JSON.
mockFetch("", 200);
const jwt = await mockJwt(
{ someClaim: true },
"https://some.issuer",
Expand All @@ -113,13 +94,16 @@ describe("getWebidFromTokenPayload", () => {
"https://some.issuer",
"https://some.clientId",
),
).rejects.toThrow(
"Malformed JWKS for [https://some.issuer] at [https://some.jwks]:",
);
).rejects.toThrow("Token verification failed");
});

it("throws if the ID token signature verification fails", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const { privateKey: anotherKey } = await generateKeyPair("ES256");
// Sign the returned JWT with a private key unrelated to the public key in the JWKS
const jwt = await mockJwt(
Expand All @@ -141,7 +125,12 @@ describe("getWebidFromTokenPayload", () => {
});

it("throws if the ID token issuer verification fails", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ someClaim: true },
"https://some.other.issuer",
Expand All @@ -160,7 +149,12 @@ describe("getWebidFromTokenPayload", () => {
});

it("throws if the ID token audience verification fails", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ someClaim: true },
"https://some.issuer",
Expand All @@ -179,7 +173,12 @@ describe("getWebidFromTokenPayload", () => {
});

it("throws if the 'webid' and the 'sub' claims are missing", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ someClaim: true },
"https://some.issuer",
Expand All @@ -196,12 +195,18 @@ describe("getWebidFromTokenPayload", () => {
});

it("throws if the 'webid' claims is missing and the 'sub' claim is not an IRI", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ sub: "some user ID" },
"https://some.issuer",
"https://some.clientId",
);

await expect(
getWebidFromTokenPayload(
jwt,
Expand All @@ -215,7 +220,12 @@ describe("getWebidFromTokenPayload", () => {
});

it("returns the WebID it the 'webid' claim exists", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ webid: "https://some.webid#me" },
"https://some.issuer",
Expand All @@ -232,7 +242,12 @@ describe("getWebidFromTokenPayload", () => {
});

it("returns the WebID it the 'sub' claim exists and it is IRI-like", async () => {
mockFetch(await mockJwks(), 200);
const mockJose = jest.requireMock("jose") as jest.Mocked<typeof Jose>;
mockJose.createRemoteJWKSet.mockReturnValue(
jest
.fn<ReturnType<(typeof Jose)["createRemoteJWKSet"]>>()
.mockResolvedValue((await mockJwk()).publicKey),
);
const jwt = await mockJwt(
{ sub: "https://some.webid#me" },
"https://some.issuer",
Expand Down
36 changes: 3 additions & 33 deletions packages/core/src/util/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,11 @@
//

// eslint-disable-next-line no-shadow
import type { JWK, JWTPayload } from "jose";
import { jwtVerify, importJWK } from "jose";
import type { JWTPayload } from "jose";
import { jwtVerify, createRemoteJWKSet } from "jose";

type WithMessage = { message: string };
type WithStack = { stack: string };

export async function fetchJwks(
jwksIri: string,
issuerIri: string,
): Promise<JWK> {
// FIXME: the following line works, but the underlying network calls don't seem
// to be mocked properly by our test code. It would be nicer to replace calls to this
// function by the following line and to fix the mocks.
// const jwks = createRemoteJWKSet(new URL(jwksIri));
const jwksResponse = await fetch(jwksIri);
if (jwksResponse.status !== 200) {
throw new Error(
`Could not fetch JWKS for [${issuerIri}] at [${jwksIri}]: ${jwksResponse.status} ${jwksResponse.statusText}`,
);
}
// The JWKS should only contain the current key for the issuer.
let jwk: JWK;
try {
jwk = (await jwksResponse.json()).keys[0] as JWK;
} catch (e) {
throw new Error(
`Malformed JWKS for [${issuerIri}] at [${jwksIri}]: ${
(e as WithMessage).message
}`,
);
}
return jwk;
}

/**
* Extract a WebID from an ID token payload based on https://github.com/solid/webid-oidc-spec.
* Note that this does not yet implement the user endpoint lookup, and only checks
Expand All @@ -69,12 +40,11 @@ export async function getWebidFromTokenPayload(
issuerIri: string,
clientId: string,
): Promise<string> {
const jwk = await fetchJwks(jwksIri, issuerIri);
let payload: JWTPayload;
try {
const { payload: verifiedPayload } = await jwtVerify(
idToken,
await importJWK(jwk),
createRemoteJWKSet(new URL(jwksIri)),
{
issuer: issuerIri,
audience: clientId,
Expand Down

0 comments on commit 2c3df07

Please sign in to comment.