diff --git a/packages/core/src/util/token.spec.ts b/packages/core/src/util/token.spec.ts index a276467931..49aa2c723c 100644 --- a/packages/core/src/util/token.spec.ts +++ b/packages/core/src/util/token.spec.ts @@ -20,18 +20,26 @@ // import { jest, it, describe, expect } from "@jest/globals"; -import type { JWTPayload, KeyLike } from "jose"; +import type * as Jose from "jose"; import { SignJWT, generateKeyPair, exportJWK } from "jose"; import { getWebidFromTokenPayload } from "./token"; +jest.mock("jose", () => { + const actualJose = jest.requireActual("jose") as typeof Jose; + return { + ...actualJose, + createRemoteJWKSet: jest.fn(), + }; +}); + describe("getWebidFromTokenPayload", () => { // Singleton keys generated on the first call to mockJwk - let publicKey: KeyLike | undefined; - let privateKey: KeyLike | undefined; + let publicKey: Jose.KeyLike | undefined; + let privateKey: Jose.KeyLike | undefined; const mockJwk = async (): Promise<{ - publicKey: KeyLike; - privateKey: KeyLike; + publicKey: Jose.KeyLike; + privateKey: Jose.KeyLike; }> => { if (typeof publicKey === "undefined" || typeof privateKey === "undefined") { const generatedPair = await generateKeyPair("ES256"); @@ -53,10 +61,10 @@ describe("getWebidFromTokenPayload", () => { }; const mockJwt = async ( - claims: JWTPayload, + claims: Jose.JWTPayload, issuer: string, audience: string, - signingKey?: KeyLike, + signingKey?: Jose.KeyLike, ): Promise => { return new SignJWT(claims) .setProtectedHeader({ alg: "ES256" }) @@ -67,40 +75,13 @@ describe("getWebidFromTokenPayload", () => { .sign(signingKey ?? (await mockJwk()).privateKey); }; - const mockFetch = ( - payload: string, - statusCode: number, - statusText?: string, - ): void => { - jest - .spyOn(globalThis, "fetch") - .mockResolvedValueOnce( - new Response(payload, { status: statusCode, statusText }), - ); - }; - - it("throws if the JWKS cannot be fetched", async () => { - mockFetch("", 404, "Not Found"); - const jwt = await mockJwt( - { someClaim: true }, - "https://some.issuer", - "https://some.clientId", - ); - await expect( - getWebidFromTokenPayload( - jwt, - "https://some.jwks", - "https://some.issuer", - "https://some.clientId", - ), - ).rejects.toThrow( - "Could not fetch JWKS for [https://some.issuer] at [https://some.jwks]: 404 Not Found", + it("throws if the JWKS retrieval fails", async () => { + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockRejectedValue("Maformed JWKS"), ); - }); - - it("throws if the JWKS is malformed", async () => { - // Invalid JSON. - mockFetch("", 200); const jwt = await mockJwt( { someClaim: true }, "https://some.issuer", @@ -113,13 +94,16 @@ describe("getWebidFromTokenPayload", () => { "https://some.issuer", "https://some.clientId", ), - ).rejects.toThrow( - "Malformed JWKS for [https://some.issuer] at [https://some.jwks]:", - ); + ).rejects.toThrow("Token verification failed"); }); it("throws if the ID token signature verification fails", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const { privateKey: anotherKey } = await generateKeyPair("ES256"); // Sign the returned JWT with a private key unrelated to the public key in the JWKS const jwt = await mockJwt( @@ -141,7 +125,12 @@ describe("getWebidFromTokenPayload", () => { }); it("throws if the ID token issuer verification fails", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { someClaim: true }, "https://some.other.issuer", @@ -160,7 +149,12 @@ describe("getWebidFromTokenPayload", () => { }); it("throws if the ID token audience verification fails", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { someClaim: true }, "https://some.issuer", @@ -179,7 +173,12 @@ describe("getWebidFromTokenPayload", () => { }); it("throws if the 'webid' and the 'sub' claims are missing", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { someClaim: true }, "https://some.issuer", @@ -196,12 +195,18 @@ describe("getWebidFromTokenPayload", () => { }); it("throws if the 'webid' claims is missing and the 'sub' claim is not an IRI", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { sub: "some user ID" }, "https://some.issuer", "https://some.clientId", ); + await expect( getWebidFromTokenPayload( jwt, @@ -215,7 +220,12 @@ describe("getWebidFromTokenPayload", () => { }); it("returns the WebID it the 'webid' claim exists", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { webid: "https://some.webid#me" }, "https://some.issuer", @@ -232,7 +242,12 @@ describe("getWebidFromTokenPayload", () => { }); it("returns the WebID it the 'sub' claim exists and it is IRI-like", async () => { - mockFetch(await mockJwks(), 200); + const mockJose = jest.requireMock("jose") as jest.Mocked; + mockJose.createRemoteJWKSet.mockReturnValue( + jest + .fn>() + .mockResolvedValue((await mockJwk()).publicKey), + ); const jwt = await mockJwt( { sub: "https://some.webid#me" }, "https://some.issuer", diff --git a/packages/core/src/util/token.ts b/packages/core/src/util/token.ts index 4b9d600ab6..deed975868 100644 --- a/packages/core/src/util/token.ts +++ b/packages/core/src/util/token.ts @@ -20,40 +20,11 @@ // // eslint-disable-next-line no-shadow -import type { JWK, JWTPayload } from "jose"; -import { jwtVerify, importJWK } from "jose"; +import type { JWTPayload } from "jose"; +import { jwtVerify, createRemoteJWKSet } from "jose"; -type WithMessage = { message: string }; type WithStack = { stack: string }; -export async function fetchJwks( - jwksIri: string, - issuerIri: string, -): Promise { - // FIXME: the following line works, but the underlying network calls don't seem - // to be mocked properly by our test code. It would be nicer to replace calls to this - // function by the following line and to fix the mocks. - // const jwks = createRemoteJWKSet(new URL(jwksIri)); - const jwksResponse = await fetch(jwksIri); - if (jwksResponse.status !== 200) { - throw new Error( - `Could not fetch JWKS for [${issuerIri}] at [${jwksIri}]: ${jwksResponse.status} ${jwksResponse.statusText}`, - ); - } - // The JWKS should only contain the current key for the issuer. - let jwk: JWK; - try { - jwk = (await jwksResponse.json()).keys[0] as JWK; - } catch (e) { - throw new Error( - `Malformed JWKS for [${issuerIri}] at [${jwksIri}]: ${ - (e as WithMessage).message - }`, - ); - } - return jwk; -} - /** * Extract a WebID from an ID token payload based on https://github.com/solid/webid-oidc-spec. * Note that this does not yet implement the user endpoint lookup, and only checks @@ -69,12 +40,11 @@ export async function getWebidFromTokenPayload( issuerIri: string, clientId: string, ): Promise { - const jwk = await fetchJwks(jwksIri, issuerIri); let payload: JWTPayload; try { const { payload: verifiedPayload } = await jwtVerify( idToken, - await importJWK(jwk), + createRemoteJWKSet(new URL(jwksIri)), { issuer: issuerIri, audience: clientId,