diff --git a/README.md b/README.md index 46bb6401..5fd71548 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Treat this like [tenant-shard-db](https://github.com/elloloop/tenant-shard-db): - **Passkeys (WebAuthn)** registration and login - **TOTP (2FA)** setup, verify, recovery codes - **QR cross-device login** -- **OAuth login** (Google, Microsoft — server consumes pre-verified ID tokens from frontend SDKs) +- **OAuth login** (Google, Microsoft, GitHub, Apple — server-owned authorization-code exchange) - **Sessions** with revoke and sign-out-everywhere - **JWT issuance** with key rotation, plus `/.well-known/jwks.json` for downstream services - **User and Group CRUD**, group membership data diff --git a/docs-site/src/pages/docs/auth/oauth.astro b/docs-site/src/pages/docs/auth/oauth.astro index 7948a699..ece52cd2 100644 --- a/docs-site/src/pages/docs/auth/oauth.astro +++ b/docs-site/src/pages/docs/auth/oauth.astro @@ -6,7 +6,7 @@ import { Code } from "astro-expressive-code/components";

OAuth

- Identity supports Google, Microsoft, and GitHub as upstream OAuth + Identity supports Google, Microsoft, GitHub, and Apple as upstream OAuth providers. The flow is server-owned: the client calls BeginOAuthLogin, redirects the browser to the returned provider URL, then completes the callback with OAuthLogin. @@ -25,16 +25,27 @@ import { Code } from "astro-expressive-code/components";

Configuration

Configure each provider you intend to support:

- + + +
+ Apple Configuration. Apple requires a Team ID, a Services ID (used as the Client ID), a Key ID, and an ECDSA private key (PEM or Base64). + Unlike other providers, Apple sends this headless flow's callback via POST. Configure an application-owned HTTPS callback that accepts code, state, and the optional user payload, then forwards them to OAuthLogin. Configure https://your-api.com/oauth/callback/apple only when using the hosted flow below. +

Start RPC

Response

@@ -134,8 +146,8 @@ message BeginOAuthLoginResponse { redirects the browser to it.
  • - GET /oauth/callback/{`{provider}`} — receives the - provider's authorization code, exchanges it through the same + GET/POST /oauth/callback/{`{provider}`} — receives the + provider's authorization code (via URL query or POST form body), exchanges it through the same OAuthLogin internals, mints an opaque single-use code (OAuthOneTimeCode, 60s TTL), and 302-redirects to return_to?code=<otc>. @@ -152,12 +164,12 @@ message BeginOAuthLoginResponse {

    Enabling the hosted flow

    The hosted flow is off by default. Enable it by - setting an allowlist of return_to origins / prefixes: + setting an allowlist of return_to origins or origin-bound path prefixes:

    - A return_to must equal or be prefixed by an allowlist - entry; anything else is rejected with HTTP 400 at + A return_to must match an allowlist origin. A path entry + permits that path and its descendants; anything else is rejected with HTTP 400 at /oauth/start before any provider round-trip. Empty (the default) disables the routes entirely — they return 404 and only the headless RPCs work. The single provider-facing redirect URI registered diff --git a/docs-site/src/pages/docs/deployment/kubernetes.astro b/docs-site/src/pages/docs/deployment/kubernetes.astro index 6167b6b5..6471ade3 100644 --- a/docs-site/src/pages/docs/deployment/kubernetes.astro +++ b/docs-site/src/pages/docs/deployment/kubernetes.astro @@ -24,8 +24,8 @@ stringData: GATEWAY_JWT_KEYS: | [{"kid":"k-2026-05","private_key_pem":"-----BEGIN RSA PRIVATE KEY-----\\n...\\n-----END RSA PRIVATE KEY-----\\n","active":true}] GATEWAY_TOTP_ENCRYPTION_KEY: "" - GATEWAY_GOOGLE_CLIENT_SECRET: "..." - GATEWAY_MICROSOFT_CLIENT_SECRET: "..."`} lang="yaml" title="secret.yaml" /> + GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET: "..." + GATEWAY_OAUTH_MICROSOFT_CLIENT_SECRET: "..."`} lang="yaml" title="secret.yaml" />

    ConfigMap

    + GATEWAY_OAUTH_GOOGLE_CLIENT_ID: "..." + GATEWAY_OAUTH_MICROSOFT_CLIENT_ID: "..."`} lang="yaml" title="configmap.yaml" />

    Deployment

    OAuth - - - - + + + + + + + + + +
    VariableDefaultPurpose
    GATEWAY_GOOGLE_CLIENT_ID(empty)Google OAuth client ID
    GATEWAY_GOOGLE_CLIENT_SECRET(empty)Google OAuth client secret
    GATEWAY_MICROSOFT_CLIENT_ID(empty)Microsoft / Entra ID client ID
    GATEWAY_MICROSOFT_CLIENT_SECRET(empty)Microsoft client secret
    GATEWAY_OAUTH_GOOGLE_CLIENT_ID(empty)Google OAuth client ID
    GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET(empty)Google OAuth client secret
    GATEWAY_OAUTH_MICROSOFT_CLIENT_ID(empty)Microsoft / Entra ID client ID
    GATEWAY_OAUTH_MICROSOFT_CLIENT_SECRET(empty)Microsoft client secret
    GATEWAY_MICROSOFT_TENANT_ID(empty)Microsoft tenant ID (or common)
    GATEWAY_OAUTH_GITHUB_CLIENT_ID(empty)GitHub OAuth client ID
    GATEWAY_OAUTH_GITHUB_CLIENT_SECRET(empty)GitHub OAuth client secret
    GATEWAY_OAUTH_APPLE_CLIENT_ID(empty)Apple Service ID (Client ID)
    GATEWAY_OAUTH_APPLE_TEAM_ID(empty)Apple Developer Team ID
    GATEWAY_OAUTH_APPLE_KEY_ID(empty)Apple Private Key ID
    GATEWAY_OAUTH_APPLE_PRIVATE_KEY(empty)Apple Private Key (PEM format or base64)

    Password / lockout

    diff --git a/docs-site/src/pages/docs/installation/docker.astro b/docs-site/src/pages/docs/installation/docker.astro index 71c0a3e0..51b11f95 100644 --- a/docs-site/src/pages/docs/installation/docker.astro +++ b/docs-site/src/pages/docs/installation/docker.astro @@ -49,8 +49,8 @@ docker pull ghcr.io/elloloop/identity:latest`} lang="bash" /> -e GATEWAY_TOTP_ENCRYPTION_KEY="$(openssl rand -base64 32)" \\ -e GATEWAY_JWT_KEYS="$(cat keyring.json)" \\ -e GATEWAY_ALLOWED_ORIGINS="https://acme.example.com" \\ - -e GATEWAY_GOOGLE_CLIENT_ID=... \\ - -e GATEWAY_GOOGLE_CLIENT_SECRET=... \\ + -e GATEWAY_OAUTH_GOOGLE_CLIENT_ID=... \\ + -e GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET=... \\ -e GATEWAY_AUTH_ALLOW_LOCAL=true \\ ghcr.io/elloloop/identity:1.2.0`} lang="bash" title="production" /> diff --git a/docs-site/src/pages/index.astro b/docs-site/src/pages/index.astro index c84c0412..592cdb1a 100644 --- a/docs-site/src/pages/index.astro +++ b/docs-site/src/pages/index.astro @@ -19,7 +19,7 @@ import { Code } from "astro-expressive-code/components";
  • Email + password — signup, login, change, recovery-email reset, lockout
  • Passkeys (WebAuthn) — registration and login flows
  • TOTP (2FA) — setup, verify, recovery codes
  • -
  • OAuth login — Google, Microsoft (server consumes pre-verified tokens from frontend SDKs)
  • +
  • OAuth login — Google, Microsoft, GitHub, Apple (server-owned authorization-code exchange)
  • QR cross-device login — initiate on a new device, approve from a logged-in one
  • JWT issuance with key rotation, plus /.well-known/jwks.json for downstream services
  • Sessions — list, revoke, sign-out-everywhere
  • diff --git a/docs/IDENTITY.md b/docs/IDENTITY.md index f7f572a0..5da4806a 100644 --- a/docs/IDENTITY.md +++ b/docs/IDENTITY.md @@ -553,7 +553,7 @@ before you ship. added to the headless state token. - **Hosted (new).** `GET /oauth/start/{provider}?return_to=` and - `GET /oauth/callback/{provider}` are plain `http.Handler` routes + `GET/POST /oauth/callback/{provider}` are plain `http.Handler` routes (not Connect RPCs — the browser is 302-redirected through them). They are thin wrappers over the same `BeginOAuthLogin` / `OAuthLogin` service internals; there is no forked exchange path. @@ -588,9 +588,10 @@ before you ship. - **`return_to` allowlist, fail-closed, disabled by default.** `GATEWAY_OAUTH_ALLOWED_RETURN_URLS` is a comma-separated list of - exact origins / URL prefixes. A `return_to` is allowed only if it - equals or is prefixed by an entry; anything else is rejected with - 400 at `/oauth/start` before any provider round-trip. **Empty + exact origins / origin-bound path prefixes. A `return_to` must match + the configured origin and, for a path entry, that path or one of its + descendants; anything else is rejected with 400 at `/oauth/start` + before any provider round-trip. **Empty disables the hosted flow entirely** — the `/oauth/*` routes are not registered (404) and only the headless RPCs work. This is the provider-facing-redirect allowlist that did not exist before: diff --git a/docs/oauth.md b/docs/oauth.md index 22d3b17d..7acffa17 100644 --- a/docs/oauth.md +++ b/docs/oauth.md @@ -1,6 +1,6 @@ # OAuth login -identity supports OAuth/OIDC sign-in with Google, Microsoft, and GitHub. +identity supports OAuth/OIDC sign-in with Google, Microsoft, GitHub, and Apple. It does the authorization-code exchange itself — the frontend is never trusted to assert the user's identity. There are two flows; a deployer can use either or both. @@ -16,14 +16,14 @@ Both run against the same provider exchangers and token-minting code. ## Enabling providers -A provider is enabled only when **both** its client id and secret are -set. Leave a provider's credentials unset to disable it. +A provider is enabled when its required credentials are set (client id and secret, or Apple's private key and IDs). Leave a provider's credentials unset to disable it. | Provider | Client ID env | Client secret env | |-----------|--------------------------------------|------------------------------------------| | Google | `GATEWAY_OAUTH_GOOGLE_CLIENT_ID` | `GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET` | | Microsoft | `GATEWAY_OAUTH_MICROSOFT_CLIENT_ID` | `GATEWAY_OAUTH_MICROSOFT_CLIENT_SECRET` | | GitHub | `GATEWAY_OAUTH_GITHUB_CLIENT_ID` | `GATEWAY_OAUTH_GITHUB_CLIENT_SECRET` | +| Apple | `GATEWAY_OAUTH_APPLE_CLIENT_ID` | `GATEWAY_OAUTH_APPLE_PRIVATE_KEY` (along with TEAM_ID and KEY_ID) | Microsoft also accepts `GATEWAY_MICROSOFT_TENANT_ID` (optional). At startup identity logs the enabled providers (`oauth_providers_enabled`) @@ -40,12 +40,12 @@ URLs identity may redirect users back to: GATEWAY_OAUTH_ALLOWED_RETURN_URLS=https://app.example.com/,https://admin.example.com/auth ``` -- Comma-separated list of exact origins or URL prefixes. -- A `return_to` is accepted only if it equals an entry or begins with - one (prefix match). Validation is **fail-closed**: anything else is - rejected with `400`. +- Comma-separated list of exact origins or origin-bound path prefixes. +- A `return_to` must have the configured origin. A path entry permits that + path and its descendants, never a lookalike host or path. Validation is + **fail-closed**: anything else is rejected with `400`. - **Empty disables the hosted flow** — `GET /oauth/start/*` and - `GET /oauth/callback/*` return `404`, and only the headless RPCs work. + `GET/POST /oauth/callback/*` return `404`, and only the headless RPCs work. The active allowlist is logged at startup (`oauth_hosted_flow_enabled` / `oauth_hosted_flow_disabled`). @@ -91,7 +91,7 @@ Browser identity Provider | (user authenticates with provider) ---------------->| | 302 -> /oauth/callback/google?state=&code= | |<----------------------------------------------------- | - | GET /oauth/callback/google?state=&code= | + | GET/POST /oauth/callback/google?state=&code= | |----------------------->| | | | verify state token | | | exchange code (PKCE) | @@ -109,8 +109,8 @@ Browser identity Provider `return_to` against the allowlist, mints state + PKCE, binds `return_to` into a signed hosted state token (tamper-proof), and 302-redirects the browser to the provider. -2. **`GET /oauth/callback/{provider}`** — the single registered redirect - URI. Recovers the state token, runs the code exchange + token mint, +2. **`GET/POST /oauth/callback/{provider}`** — the single registered redirect + URI. Apple uses `POST`; others use `GET`. Recovers the state token, runs the code exchange + token mint, mints a single-use one-time code, and 302-redirects to `return_to?code=`. On any failure it returns a generic `400` (it cannot trust an unverified `return_to`) and logs server-side. @@ -138,11 +138,11 @@ caller. `{authorization_url, state, state_token, code_verifier, expires_in}`. The frontend redirects the user to `authorization_url` (which uses the frontend's own `redirect_uri`). -2. The provider redirects back to the frontend's callback page with - `?code=&state=`. -3. **`OAuthLogin{code, provider, redirect_uri, state, state_token}`** — - identity verifies the state token, exchanges the code, and returns - `{user, access_token, refresh_token, expires_in}`. +2. The provider redirects back to the frontend's callback page. Most providers + redirect via GET with `?code=&state=`. Apple redirects via HTTP POST (`form_post`) + with `code`, `state`, and an optional `user` JSON payload as form data. +3. **`OAuthLogin{code, provider, redirect_uri, state, state_token, apple_user_payload}`** — + identity supports server-owned authorization-code exchange. It does not consume pre-verified frontend SDK ID tokens. This guarantees you own the user relationship, keeps identity keys off the frontend, and enables robust refresh token flows. The headless flow has no `return_to` allowlist: the frontend supplies and owns its own `redirect_uri`, which it must register with the diff --git a/gen/go/identity/v1/identity.pb.go b/gen/go/identity/v1/identity.pb.go index c5a74040..bf178923 100644 --- a/gen/go/identity/v1/identity.pb.go +++ b/gen/go/identity/v1/identity.pb.go @@ -1898,7 +1898,7 @@ func (x *ListGroupMembersResponse) GetMembers() []*User { type BeginOAuthLoginRequest struct { state protoimpl.MessageState `protogen:"open.v1"` RedirectUri string `protobuf:"bytes,1,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` - Provider string `protobuf:"bytes,2,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", or "github" + Provider string `protobuf:"bytes,2,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", or "apple" Tenant string `protobuf:"bytes,3,opt,name=tenant,proto3" json:"tenant,omitempty"` // Microsoft tenant ID (optional) unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -2032,16 +2032,17 @@ func (x *BeginOAuthLoginResponse) GetExpiresIn() int32 { } type OAuthLoginRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` - RedirectUri string `protobuf:"bytes,2,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` - Provider string `protobuf:"bytes,3,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", or "github" - CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"` // PKCE (optional) - Tenant string `protobuf:"bytes,5,opt,name=tenant,proto3" json:"tenant,omitempty"` // Microsoft tenant ID (optional) - State string `protobuf:"bytes,6,opt,name=state,proto3" json:"state,omitempty"` // authorization callback state (required for server-owned flow) - StateToken string `protobuf:"bytes,7,opt,name=state_token,json=stateToken,proto3" json:"state_token,omitempty"` // opaque token minted by BeginOAuthLogin - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` + RedirectUri string `protobuf:"bytes,2,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` + Provider string `protobuf:"bytes,3,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", or "apple" + CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"` // PKCE (optional) + Tenant string `protobuf:"bytes,5,opt,name=tenant,proto3" json:"tenant,omitempty"` // Microsoft tenant ID (optional) + State string `protobuf:"bytes,6,opt,name=state,proto3" json:"state,omitempty"` // authorization callback state (required for server-owned flow) + StateToken string `protobuf:"bytes,7,opt,name=state_token,json=stateToken,proto3" json:"state_token,omitempty"` // opaque token minted by BeginOAuthLogin + AppleUserPayload string `protobuf:"bytes,8,opt,name=apple_user_payload,json=appleUserPayload,proto3" json:"apple_user_payload,omitempty"` // one-time user payload from Apple's form_post callback + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *OAuthLoginRequest) Reset() { @@ -2123,6 +2124,13 @@ func (x *OAuthLoginRequest) GetStateToken() string { return "" } +func (x *OAuthLoginRequest) GetAppleUserPayload() string { + if x != nil { + return x.AppleUserPayload + } + return "" +} + type OAuthLoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` @@ -2192,7 +2200,7 @@ func (x *OAuthLoginResponse) GetExpiresIn() int32 { } // RedeemOAuthCode exchanges the single-use one-time code handed to the -// SPA by the hosted OAuth callback (GET /oauth/callback/{provider} -> +// SPA by the hosted OAuth callback (GET/POST /oauth/callback/{provider} -> // 302 return_to?code=) for a backend-issued token pair. The code // is single-use and short-lived; a replay returns CodeUnauthenticated. type RedeemOAuthCodeRequest struct { @@ -10679,7 +10687,7 @@ func (x *GetProjectConfigResponse) GetConfigJson() string { // LinkedIdentity is one connected provider identity for the authenticated user. type LinkedIdentity struct { state protoimpl.MessageState `protogen:"open.v1"` - Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", ... + Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", "apple", ... ProviderUserId string `protobuf:"bytes,2,opt,name=provider_user_id,json=providerUserId,proto3" json:"provider_user_id,omitempty"` // stable provider subject EmailAtLinkTime string `protobuf:"bytes,3,opt,name=email_at_link_time,json=emailAtLinkTime,proto3" json:"email_at_link_time,omitempty"` // email the provider asserted when linked LinkedAt int64 `protobuf:"varint,4,opt,name=linked_at,json=linkedAt,proto3" json:"linked_at,omitempty"` // epoch ms the link was created @@ -10835,7 +10843,7 @@ type LinkIdentityRequest struct { state protoimpl.MessageState `protogen:"open.v1"` Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` // authorization code from the provider callback RedirectUri string `protobuf:"bytes,2,opt,name=redirect_uri,json=redirectUri,proto3" json:"redirect_uri,omitempty"` - Provider string `protobuf:"bytes,3,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", ... + Provider string `protobuf:"bytes,3,opt,name=provider,proto3" json:"provider,omitempty"` // "google", "microsoft", "github", "apple", ... CodeVerifier string `protobuf:"bytes,4,opt,name=code_verifier,json=codeVerifier,proto3" json:"code_verifier,omitempty"` // PKCE (optional) State string `protobuf:"bytes,5,opt,name=state,proto3" json:"state,omitempty"` // authorization callback state StateToken string `protobuf:"bytes,6,opt,name=state_token,json=stateToken,proto3" json:"state_token,omitempty"` // opaque token minted by BeginOAuthLogin @@ -11189,7 +11197,7 @@ const file_identity_v1_identity_proto_rawDesc = "" + "stateToken\x12#\n" + "\rcode_verifier\x18\x04 \x01(\tR\fcodeVerifier\x12\x1d\n" + "\n" + - "expires_in\x18\x05 \x01(\x05R\texpiresIn\"\xda\x01\n" + + "expires_in\x18\x05 \x01(\x05R\texpiresIn\"\x88\x02\n" + "\x11OAuthLoginRequest\x12\x12\n" + "\x04code\x18\x01 \x01(\tR\x04code\x12!\n" + "\fredirect_uri\x18\x02 \x01(\tR\vredirectUri\x12\x1a\n" + @@ -11198,7 +11206,8 @@ const file_identity_v1_identity_proto_rawDesc = "" + "\x06tenant\x18\x05 \x01(\tR\x06tenant\x12\x14\n" + "\x05state\x18\x06 \x01(\tR\x05state\x12\x1f\n" + "\vstate_token\x18\a \x01(\tR\n" + - "stateToken\"\xa8\x01\n" + + "stateToken\x12,\n" + + "\x12apple_user_payload\x18\b \x01(\tR\x10appleUserPayload\"\xa8\x01\n" + "\x12OAuthLoginResponse\x12%\n" + "\x04user\x18\x01 \x01(\v2\x11.identity.v1.UserR\x04user\x12!\n" + "\faccess_token\x18\x02 \x01(\tR\vaccessToken\x12#\n" + diff --git a/gen/openapi/identity.openapi.yaml b/gen/openapi/identity.openapi.yaml index 1def107f..731e29d6 100644 --- a/gen/openapi/identity.openapi.yaml +++ b/gen/openapi/identity.openapi.yaml @@ -3785,7 +3785,7 @@ components: provider: type: string title: provider - description: '"google", "microsoft", or "github"' + description: '"google", "microsoft", "github", or "apple"' tenant: type: string title: tenant @@ -4552,7 +4552,7 @@ components: provider: type: string title: provider - description: '"google", "microsoft", "github", ...' + description: '"google", "microsoft", "github", "apple", ...' codeVerifier: type: string title: code_verifier @@ -4588,7 +4588,7 @@ components: provider: type: string title: provider - description: '"google", "microsoft", "github", ...' + description: '"google", "microsoft", "github", "apple", ...' providerUserId: type: string title: provider_user_id @@ -4975,7 +4975,7 @@ components: provider: type: string title: provider - description: '"google", "microsoft", or "github"' + description: '"google", "microsoft", "github", or "apple"' codeVerifier: type: string title: code_verifier @@ -4992,6 +4992,10 @@ components: type: string title: state_token description: opaque token minted by BeginOAuthLogin + appleUserPayload: + type: string + title: apple_user_payload + description: one-time user payload from Apple's form_post callback title: OAuthLoginRequest additionalProperties: false identity.v1.OAuthLoginResponse: @@ -5256,7 +5260,7 @@ components: additionalProperties: false description: |- RedeemOAuthCode exchanges the single-use one-time code handed to the - SPA by the hosted OAuth callback (GET /oauth/callback/{provider} -> + SPA by the hosted OAuth callback (GET/POST /oauth/callback/{provider} -> 302 return_to?code=) for a backend-issued token pair. The code is single-use and short-lived; a replay returns CodeUnauthenticated. identity.v1.RedeemOAuthCodeResponse: diff --git a/internal/app/oauth.go b/internal/app/oauth.go index 1ecc71ef..d12b2cc7 100644 --- a/internal/app/oauth.go +++ b/internal/app/oauth.go @@ -8,10 +8,11 @@ import ( ) // buildOAuthRegistry constructs an oauth.Registry from the gateway -// configuration. A provider is registered only if BOTH the client ID -// and client secret are non-empty for that provider; this lets -// operators leave a provider's credentials unset to disable it (rather -// than gating each provider behind its own boolean). +// configuration. A provider is registered only if its required credentials +// are non-empty; this lets operators leave a provider's credentials unset +// to disable it (rather than gating each provider behind its own boolean). +// Apple requires client ID, team ID, key ID, and private key. Others +// require client ID and client secret. // // The returned registry is never nil so the AuthService can call // (*Registry).Len() unconditionally. @@ -40,13 +41,21 @@ func buildOAuthRegistry(cfg *config.Config, logger *zap.Logger) *oauth.Registry ClientSecret: cfg.GitHubClientSecret, })) } + if cfg.AppleClientID != "" && cfg.AppleTeamID != "" && cfg.AppleKeyID != "" && cfg.ApplePrivateKey != "" { + r.Register("apple", oauth.NewApple(oauth.AppleConfig{ + ClientID: cfg.AppleClientID, + TeamID: cfg.AppleTeamID, + KeyID: cfg.AppleKeyID, + PrivateKey: cfg.ApplePrivateKey, + })) + } if r.Len() == 0 { logger.Warn( "oauth_disabled_no_providers_configured", zap.String("hint", - "set GATEWAY_GOOGLE_CLIENT_ID/SECRET, GATEWAY_MICROSOFT_CLIENT_ID/SECRET, "+ - "or GATEWAY_GITHUB_CLIENT_ID/SECRET to enable OAuth login"), + "set GATEWAY_OAUTH_GOOGLE_CLIENT_ID/SECRET, GATEWAY_OAUTH_MICROSOFT_CLIENT_ID/SECRET, "+ + "GATEWAY_OAUTH_GITHUB_CLIENT_ID/SECRET, or GATEWAY_OAUTH_APPLE_... to enable OAuth login"), ) } else { logger.Info( diff --git a/internal/app/oauth_hosted_http.go b/internal/app/oauth_hosted_http.go index 77892103..aec1db1e 100644 --- a/internal/app/oauth_hosted_http.go +++ b/internal/app/oauth_hosted_http.go @@ -1,6 +1,8 @@ package app import ( + "crypto/rand" + "encoding/base64" "net/http" "net/url" "strings" @@ -13,7 +15,7 @@ import ( // hostedOAuthHandler serves the browser-facing hosted OAuth endpoints: // // GET /oauth/start/{provider}?return_to= -// GET /oauth/callback/{provider} +// GET/POST /oauth/callback/{provider} // // These are plain HTTP routes (not Connect RPCs) because the browser is // 302-redirected through them. They are thin wrappers over the existing @@ -30,6 +32,11 @@ type hostedOAuthHandler struct { logger *zap.Logger } +const ( + hostedOAuthCSRFCookiePrefix = "__Host-oauth_csrf_" + hostedOAuthCSRFMaxTokens = 16 +) + // register wires the hosted routes onto mux. It is a no-op when the // allowlist is empty (hosted flow disabled), leaving the routes // unregistered so they 404 — the headless RPCs are unaffected. @@ -62,8 +69,19 @@ func (h *hostedOAuthHandler) handleStart(w http.ResponseWriter, r *http.Request) return } + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + csrfToken := base64.RawURLEncoding.EncodeToString(b[:]) + csrfTokens, ok := hostedOAuthCSRFTokens(r, provider) + if !ok { + csrfTokens = nil + } + redirectURI := h.callbackURL(r, provider) - result, err := h.auth.BeginHostedOAuth(r.Context(), provider, redirectURI, returnTo) + result, err := h.auth.BeginHostedOAuth(r.Context(), provider, redirectURI, returnTo, csrfToken) if err != nil { h.logger.Info("hosted_oauth_start_failed", zap.String("provider", provider), zap.Error(err)) status := http.StatusBadRequest @@ -76,36 +94,51 @@ func (h *hostedOAuthHandler) handleStart(w http.ResponseWriter, r *http.Request) // #nosec G710 -- the redirect target is the provider authorization // URL built server-side by BeginHostedOAuth from the registered // provider config, not from request input. + http.SetCookie(w, hostedOAuthCSRFCookie(provider, appendHostedOAuthCSRFToken(csrfTokens, csrfToken), 900)) http.Redirect(w, r, result.AuthorizationURL, http.StatusFound) } func (h *hostedOAuthHandler) handleCallback(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { + if r.Method != http.MethodGet && r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } + if r.Method == http.MethodPost { + r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit + } + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + provider := pathProvider(r.URL.Path, "/oauth/callback/") if provider == "" { http.Error(w, "provider is required", http.StatusBadRequest) return } - q := r.URL.Query() - state := q.Get("state") + state := r.FormValue("state") // Provider-side error (user denied consent, etc). We cannot recover // return_to without a valid state token, so there is nowhere safe to // redirect — surface a generic 400 and log server-side. - if provErr := q.Get("error"); provErr != "" { + if provErr := r.FormValue("error"); provErr != "" { h.logger.Info("hosted_oauth_provider_error", zap.String("provider", provider), zap.String("error", provErr)) http.Error(w, "oauth failed", http.StatusBadRequest) return } + csrfTokens, ok := hostedOAuthCSRFTokens(r, provider) + if !ok { + h.logger.Info("hosted_oauth_csrf_cookie_invalid", zap.String("provider", provider)) + http.Error(w, "oauth failed", http.StatusBadRequest) + return + } + result, err := h.auth.CompleteHostedOAuth( - r.Context(), provider, q.Get("code"), state, - clientIPFromRequest(r), r.UserAgent(), + r.Context(), provider, r.FormValue("code"), state, r.FormValue("user"), + clientIPFromRequest(r), r.UserAgent(), csrfTokens, ) if err != nil { // We do not have a verified return_to here (the state token failed @@ -121,9 +154,80 @@ func (h *hostedOAuthHandler) handleCallback(w http.ResponseWriter, r *http.Reque // tamper-proof hosted state token whose return_to was validated // against the GATEWAY_OAUTH_ALLOWED_RETURN_URLS allowlist at /start // time. It is not raw request input. + remainingCSRFTokens := removeHostedOAuthCSRFToken(csrfTokens, result.CSRFToken) + maxAge := 900 + if len(remainingCSRFTokens) == 0 { + maxAge = -1 + } + http.SetCookie(w, hostedOAuthCSRFCookie(provider, remainingCSRFTokens, maxAge)) http.Redirect(w, r, appendQueryParam(result.ReturnTo, "code", result.Code), http.StatusFound) } +func hostedOAuthCSRFCookie(provider string, tokens []string, maxAge int) *http.Cookie { + sameSite := http.SameSiteLaxMode + if provider == "apple" { + sameSite = http.SameSiteNoneMode + } + return &http.Cookie{ // #nosec G124 -- Secure and HttpOnly are fixed; SameSite is Lax or None for Apple's cross-site form POST. + Name: hostedOAuthCSRFCookieName(provider), + Value: strings.Join(tokens, "."), + Path: "/", + MaxAge: maxAge, + HttpOnly: true, + Secure: true, + SameSite: sameSite, + } +} + +func hostedOAuthCSRFCookieName(provider string) string { + return hostedOAuthCSRFCookiePrefix + provider +} + +func hostedOAuthCSRFTokens(r *http.Request, provider string) ([]string, bool) { + var value string + found := false + for _, cookie := range r.Cookies() { + if cookie.Name != hostedOAuthCSRFCookieName(provider) { + continue + } + if found { + return nil, false + } + found = true + value = cookie.Value + } + if !found { + return nil, false + } + tokens := strings.Split(value, ".") + if len(tokens) > hostedOAuthCSRFMaxTokens { + return nil, false + } + for _, token := range tokens { + if token == "" { + return nil, false + } + } + return tokens, true +} + +func appendHostedOAuthCSRFToken(tokens []string, token string) []string { + if len(tokens) == hostedOAuthCSRFMaxTokens { + tokens = tokens[1:] + } + return append(tokens, token) +} + +func removeHostedOAuthCSRFToken(tokens []string, token string) []string { + remaining := make([]string, 0, len(tokens)) + for _, candidate := range tokens { + if candidate != token { + remaining = append(remaining, candidate) + } + } + return remaining +} + // callbackURL reconstructs this server's single redirect URI for the // provider from the incoming request, honoring X-Forwarded-Proto / // X-Forwarded-Host when present (set by a trusted reverse proxy). The diff --git a/internal/app/oauth_hosted_http_test.go b/internal/app/oauth_hosted_http_test.go index e2c2983e..9edc3217 100644 --- a/internal/app/oauth_hosted_http_test.go +++ b/internal/app/oauth_hosted_http_test.go @@ -34,7 +34,7 @@ func (p *appTestStubProvider) AuthorizationURL(_ context.Context, redirectURI, s return u.String(), nil } -func (p *appTestStubProvider) Exchange(_ context.Context, _, _ string) (*oauth.Identity, error) { +func (p *appTestStubProvider) Exchange(_ context.Context, _ oauth.ExchangeParams) (*oauth.Identity, error) { if p.err != nil { return nil, p.err } @@ -90,8 +90,12 @@ func newHostedTestHandler(t *testing.T, allowlist string, reg *oauth.Registry) h } func hostedTestRegistry(p oauth.Exchanger) *oauth.Registry { + return hostedTestRegistryFor("google", p) +} + +func hostedTestRegistryFor(provider string, p oauth.Exchanger) *oauth.Registry { reg := oauth.NewRegistry() - reg.Register("google", p) + reg.Register(provider, p) return reg } @@ -116,6 +120,36 @@ func TestHostedHTTP_StartHappyPath(t *testing.T) { } } +func TestHostedHTTP_StartSetsCSRFCookie(t *testing.T) { + for _, tt := range []struct { + provider string + sameSite http.SameSite + }{ + {provider: "google", sameSite: http.SameSiteLaxMode}, + {provider: "apple", sameSite: http.SameSiteNoneMode}, + } { + t.Run(tt.provider, func(t *testing.T) { + h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistryFor(tt.provider, &appTestStubProvider{})) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, + "/oauth/start/"+tt.provider+"?return_to="+url.QueryEscape("https://app.test/finish"), nil)) + + if rr.Code != http.StatusFound { + t.Fatalf("status = %d, want 302", rr.Code) + } + cookies := rr.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("cookie count = %d, want 1", len(cookies)) + } + cookie := cookies[0] + if cookie.Name != hostedOAuthCSRFCookieName(tt.provider) || cookie.Path != "/" || cookie.Domain != "" || + !cookie.HttpOnly || !cookie.Secure || cookie.SameSite != tt.sameSite { + t.Fatalf("csrf cookie = %#v", cookie) + } + }) + } +} + func TestHostedHTTP_StartRejectsReturnTo(t *testing.T) { h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) rr := httptest.NewRecorder() @@ -169,8 +203,91 @@ func TestHostedHTTP_FullStartCallback(t *testing.T) { // Callback with the state token + a code. cbRR := httptest.NewRecorder() - h.ServeHTTP(cbRR, httptest.NewRequest(http.MethodGet, - "/oauth/callback/google?state="+url.QueryEscape(stateToken)+"&code=auth-xyz", nil)) + cbReq := httptest.NewRequest(http.MethodGet, "/oauth/callback/google?state="+url.QueryEscape(stateToken)+"&code=auth-xyz", nil) + for _, c := range startRR.Result().Cookies() { + cbReq.AddCookie(c) + } + h.ServeHTTP(cbRR, cbReq) + if cbRR.Code != http.StatusFound { + t.Fatalf("callback status = %d, want 302; body=%q", cbRR.Code, cbRR.Body.String()) + } + redir := cbRR.Header().Get("Location") + if !strings.HasPrefix(redir, "https://app.test/finish") { + t.Fatalf("callback redirect = %q", redir) + } + cb, _ := url.Parse(redir) + if cb.Query().Get("code") == "" { + t.Fatal("callback redirect carried no one-time code") + } +} + +func TestHostedHTTP_ConcurrentStartsCompleteIndependently(t *testing.T) { + h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) + + firstStartRR := httptest.NewRecorder() + h.ServeHTTP(firstStartRR, httptest.NewRequest(http.MethodGet, + "/oauth/start/google?return_to="+url.QueryEscape("https://app.test/finish"), nil)) + firstLocation, _ := url.Parse(firstStartRR.Header().Get("Location")) + firstState := firstLocation.Query().Get("state") + firstCookie := firstStartRR.Result().Cookies()[0] + + secondStartRR := httptest.NewRecorder() + secondStartReq := httptest.NewRequest(http.MethodGet, + "/oauth/start/google?return_to="+url.QueryEscape("https://app.test/finish"), nil) + secondStartReq.AddCookie(firstCookie) + h.ServeHTTP(secondStartRR, secondStartReq) + secondLocation, _ := url.Parse(secondStartRR.Header().Get("Location")) + secondState := secondLocation.Query().Get("state") + secondCookie := secondStartRR.Result().Cookies()[0] + if got := len(strings.Split(secondCookie.Value, ".")); got != 2 { + t.Fatalf("csrf token count = %d, want 2", got) + } + + firstCallbackRR := httptest.NewRecorder() + firstCallbackReq := httptest.NewRequest(http.MethodGet, + "/oauth/callback/google?state="+url.QueryEscape(firstState)+"&code=auth-first", nil) + firstCallbackReq.AddCookie(secondCookie) + h.ServeHTTP(firstCallbackRR, firstCallbackReq) + if firstCallbackRR.Code != http.StatusFound { + t.Fatalf("first callback status = %d, want 302; body=%q", firstCallbackRR.Code, firstCallbackRR.Body.String()) + } + remainingCookie := firstCallbackRR.Result().Cookies()[0] + if got := len(strings.Split(remainingCookie.Value, ".")); got != 1 { + t.Fatalf("remaining csrf token count = %d, want 1", got) + } + + secondCallbackRR := httptest.NewRecorder() + secondCallbackReq := httptest.NewRequest(http.MethodGet, + "/oauth/callback/google?state="+url.QueryEscape(secondState)+"&code=auth-second", nil) + secondCallbackReq.AddCookie(remainingCookie) + h.ServeHTTP(secondCallbackRR, secondCallbackReq) + if secondCallbackRR.Code != http.StatusFound { + t.Fatalf("second callback status = %d, want 302; body=%q", secondCallbackRR.Code, secondCallbackRR.Body.String()) + } +} + +func TestHostedHTTP_FullStartCallback_FormPost(t *testing.T) { + h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) + + // Start to obtain a valid state token. + startRR := httptest.NewRecorder() + h.ServeHTTP(startRR, httptest.NewRequest(http.MethodGet, + "/oauth/start/google?return_to="+url.QueryEscape("https://app.test/finish"), nil)) + loc, _ := url.Parse(startRR.Header().Get("Location")) + stateToken := loc.Query().Get("state") + + // Callback with the state token + a code via POST form body. + cbRR := httptest.NewRecorder() + form := url.Values{} + form.Set("state", stateToken) + form.Set("code", "auth-xyz-form") + req := httptest.NewRequest(http.MethodPost, "/oauth/callback/google", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + for _, c := range startRR.Result().Cookies() { + req.AddCookie(c) + } + h.ServeHTTP(cbRR, req) + if cbRR.Code != http.StatusFound { t.Fatalf("callback status = %d, want 302; body=%q", cbRR.Code, cbRR.Body.String()) } @@ -184,6 +301,60 @@ func TestHostedHTTP_FullStartCallback(t *testing.T) { } } +func TestHostedHTTP_CallbackRejectsInvalidCSRFCookie(t *testing.T) { + h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) + startRR := httptest.NewRecorder() + h.ServeHTTP(startRR, httptest.NewRequest(http.MethodGet, + "/oauth/start/google?return_to="+url.QueryEscape("https://app.test/finish"), nil)) + loc, _ := url.Parse(startRR.Header().Get("Location")) + stateToken := loc.Query().Get("state") + + for _, tt := range []struct { + name string + addCookies func(*http.Request) + }{ + {name: "missing"}, + {name: "mismatched", addCookies: func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: hostedOAuthCSRFCookieName("google"), + Value: "wrong", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + }}, + {name: "duplicate", addCookies: func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: hostedOAuthCSRFCookieName("google"), + Value: "first", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + req.AddCookie(&http.Cookie{ + Name: hostedOAuthCSRFCookieName("google"), + Value: "second", + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + }}, + } { + t.Run(tt.name, func(t *testing.T) { + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, + "/oauth/callback/google?state="+url.QueryEscape(stateToken)+"&code=auth-xyz", nil) + if tt.addCookies != nil { + tt.addCookies(req) + } + h.ServeHTTP(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", rr.Code) + } + }) + } +} + func TestHostedHTTP_CallbackProviderError(t *testing.T) { h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) rr := httptest.NewRecorder() @@ -207,7 +378,7 @@ func TestHostedHTTP_CallbackBadState(t *testing.T) { func TestHostedHTTP_CallbackMethodNotAllowed(t *testing.T) { h := newHostedTestHandler(t, "https://app.test/", hostedTestRegistry(&appTestStubProvider{})) rr := httptest.NewRecorder() - h.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/oauth/callback/google", nil)) + h.ServeHTTP(rr, httptest.NewRequest(http.MethodPut, "/oauth/callback/google", nil)) if rr.Code != http.StatusMethodNotAllowed { t.Fatalf("status = %d, want 405", rr.Code) } diff --git a/internal/app/wiring_test.go b/internal/app/wiring_test.go index 0a9ca18c..e1e4f267 100644 --- a/internal/app/wiring_test.go +++ b/internal/app/wiring_test.go @@ -29,22 +29,26 @@ func TestBuildOAuthRegistry(t *testing.T) { cfg := &config.Config{ GoogleClientID: "google-client", MicrosoftClientID: "microsoft-client", - MicrosoftTenantID: "tenant", + MicrosoftTenantID: "common", GitHubClientID: "github-client", + AppleClientID: "apple-client", } cfg.GoogleClientSecret = testCredential("google") cfg.MicrosoftClientSecret = testCredential("microsoft") cfg.GitHubClientSecret = testCredential("github") + cfg.ApplePrivateKey = testCredential("apple") + cfg.AppleTeamID = "team" + cfg.AppleKeyID = "key" registry := buildOAuthRegistry(cfg, zap.NewNop()) - if registry.Len() != 3 { + if registry.Len() != 4 { t.Fatalf("registry Len = %d", registry.Len()) } got := make(map[string]bool) for _, provider := range registry.Providers() { got[provider] = true } - for _, provider := range []string{"google", "microsoft", "github"} { + for _, provider := range []string{"google", "microsoft", "github", "apple"} { if !got[provider] { t.Fatalf("provider %q not registered; got %v", provider, registry.Providers()) } diff --git a/internal/config/config.go b/internal/config/config.go index ce14463b..8782d1a1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -218,16 +218,20 @@ type Config struct { MicrosoftTenantID string GitHubClientID string GitHubClientSecret string + AppleClientID string + AppleTeamID string + AppleKeyID string + ApplePrivateKey string // OAuthAllowedReturnURLs is the comma-separated allowlist of app URLs // the hosted OAuth flow may redirect back to (the `return_to` param of - // GET /oauth/start/{provider}). Each entry is an exact origin or a URL - // prefix; a return_to matches when it equals an entry or begins with - // an entry. Validation is fail-closed: a return_to that matches no - // entry is rejected with 400. + // GET /oauth/start/{provider}). Each entry is an exact origin or a path + // prefix. A return_to must match the configured origin and, for path + // entries, the configured path or one of its descendants. Validation is + // fail-closed: a return_to that matches no entry is rejected with 400. // - // Empty disables the hosted flow entirely — GET /oauth/start and GET - // /oauth/callback return 404. The headless BeginOAuthLogin / OAuthLogin + // Empty disables the hosted flow entirely — GET /oauth/start and + // GET/POST /oauth/callback return 404. The headless BeginOAuthLogin / OAuthLogin // RPCs are unaffected. Driven by GATEWAY_OAUTH_ALLOWED_RETURN_URLS. OAuthAllowedReturnURLs string @@ -564,13 +568,17 @@ func Load() *Config { ProjectResolutionCacheTTLSeconds: envInt("GATEWAY_PROJECT_RESOLUTION_CACHE_TTL_SECONDS", 30), ProjectResolutionCacheMaxEntries: envInt("GATEWAY_PROJECT_RESOLUTION_CACHE_MAX_ENTRIES", 10000), - GoogleClientID: envStr("GATEWAY_OAUTH_GOOGLE_CLIENT_ID", envStr("GATEWAY_GOOGLE_CLIENT_ID", "")), - GoogleClientSecret: envStr("GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET", envStr("GATEWAY_GOOGLE_CLIENT_SECRET", "")), - MicrosoftClientID: envStr("GATEWAY_OAUTH_MICROSOFT_CLIENT_ID", envStr("GATEWAY_MICROSOFT_CLIENT_ID", "")), - MicrosoftClientSecret: envStr("GATEWAY_OAUTH_MICROSOFT_CLIENT_SECRET", envStr("GATEWAY_MICROSOFT_CLIENT_SECRET", "")), + GoogleClientID: envStr("GATEWAY_OAUTH_GOOGLE_CLIENT_ID", ""), + GoogleClientSecret: envStr("GATEWAY_OAUTH_GOOGLE_CLIENT_SECRET", ""), + MicrosoftClientID: envStr("GATEWAY_OAUTH_MICROSOFT_CLIENT_ID", ""), + MicrosoftClientSecret: envStr("GATEWAY_OAUTH_MICROSOFT_CLIENT_SECRET", ""), MicrosoftTenantID: envStr("GATEWAY_MICROSOFT_TENANT_ID", ""), GitHubClientID: envStr("GATEWAY_OAUTH_GITHUB_CLIENT_ID", ""), GitHubClientSecret: envStr("GATEWAY_OAUTH_GITHUB_CLIENT_SECRET", ""), + AppleClientID: envStr("GATEWAY_OAUTH_APPLE_CLIENT_ID", ""), + AppleTeamID: envStr("GATEWAY_OAUTH_APPLE_TEAM_ID", ""), + AppleKeyID: envStr("GATEWAY_OAUTH_APPLE_KEY_ID", ""), + ApplePrivateKey: envStr("GATEWAY_OAUTH_APPLE_PRIVATE_KEY", ""), OAuthAllowedReturnURLs: envStr("GATEWAY_OAUTH_ALLOWED_RETURN_URLS", ""), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8c461933..34370fdb 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -89,6 +89,18 @@ func TestLoad_Defaults(t *testing.T) { if cfg.SweeperGraceSeconds != 60 { t.Errorf("SweeperGraceSeconds: want 60, got %d", cfg.SweeperGraceSeconds) } + if cfg.AppleClientID != "" { + t.Errorf("AppleClientID: want empty default, got %q", cfg.AppleClientID) + } + if cfg.AppleTeamID != "" { + t.Errorf("AppleTeamID: want empty default, got %q", cfg.AppleTeamID) + } + if cfg.AppleKeyID != "" { + t.Errorf("AppleKeyID: want empty default, got %q", cfg.AppleKeyID) + } + if cfg.ApplePrivateKey != "" { + t.Errorf("ApplePrivateKey: want empty default, got %q", cfg.ApplePrivateKey) + } } // TestLoad_SweeperDisabledWhenIntervalZero asserts the documented @@ -119,6 +131,10 @@ func TestLoad_OverrideFromEnv(t *testing.T) { t.Setenv("GATEWAY_PASSWORD_RESET_ENABLED", "false") t.Setenv("GATEWAY_LOGIN_MAX_FAILED_ATTEMPTS", "10") t.Setenv("GATEWAY_TOTP_ISSUER", "My Corp") + t.Setenv("GATEWAY_OAUTH_APPLE_CLIENT_ID", "apple-client") + t.Setenv("GATEWAY_OAUTH_APPLE_TEAM_ID", "apple-team") + t.Setenv("GATEWAY_OAUTH_APPLE_KEY_ID", "apple-key") + t.Setenv("GATEWAY_OAUTH_APPLE_PRIVATE_KEY", "apple-private") cfg := Load() @@ -155,6 +171,18 @@ func TestLoad_OverrideFromEnv(t *testing.T) { if cfg.TOTPIssuer != "My Corp" { t.Errorf("TOTPIssuer: want 'My Corp', got %q", cfg.TOTPIssuer) } + if cfg.AppleClientID != "apple-client" { + t.Errorf("AppleClientID: want apple-client, got %q", cfg.AppleClientID) + } + if cfg.AppleTeamID != "apple-team" { + t.Errorf("AppleTeamID: want apple-team, got %q", cfg.AppleTeamID) + } + if cfg.AppleKeyID != "apple-key" { + t.Errorf("AppleKeyID: want apple-key, got %q", cfg.AppleKeyID) + } + if cfg.ApplePrivateKey != "apple-private" { + t.Errorf("ApplePrivateKey: want apple-private, got %q", cfg.ApplePrivateKey) + } } // TestEnvStr_Default verifies envStr returns the default for unset vars. diff --git a/internal/connect/handler_auth.go b/internal/connect/handler_auth.go index f8efa36f..4070ed8f 100644 --- a/internal/connect/handler_auth.go +++ b/internal/connect/handler_auth.go @@ -49,17 +49,17 @@ func (h *IdentityHandler) OAuthLogin( ipAddr := clientIP(req.Header()) userAgent := clientUserAgent(req.Header()) - result, err := h.auth.OAuthLogin( - ctx, - req.Msg.Code, - req.Msg.Provider, - req.Msg.RedirectUri, - req.Msg.CodeVerifier, - req.Msg.State, - req.Msg.StateToken, - ipAddr, - userAgent, - ) + result, err := h.auth.OAuthLogin(ctx, service.OAuthLoginParams{ + Code: req.Msg.Code, + Provider: req.Msg.Provider, + RedirectURI: req.Msg.RedirectUri, + CodeVerifier: req.Msg.CodeVerifier, + State: req.Msg.State, + StateToken: req.Msg.StateToken, + AppleUserPayload: req.Msg.AppleUserPayload, + IPAddr: ipAddr, + UserAgent: userAgent, + }) if err != nil { return nil, toConnectError(err) } diff --git a/internal/connect/handlers_test.go b/internal/connect/handlers_test.go index 6b303888..06950795 100644 --- a/internal/connect/handlers_test.go +++ b/internal/connect/handlers_test.go @@ -491,7 +491,7 @@ func TestRedeemOAuthCode_DisabledUnavailable(t *testing.T) { type connectOAuthExchanger struct{} -func (connectOAuthExchanger) Exchange(context.Context, string, string) (*oauth.Identity, error) { +func (connectOAuthExchanger) Exchange(_ context.Context, _ oauth.ExchangeParams) (*oauth.Identity, error) { return &oauth.Identity{ Provider: "google", ProviderUserID: "connect-user", diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 294a21fa..f813df51 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -65,7 +65,7 @@ var AuthExemptPaths = map[string]bool{ } // hostedOAuthPrefix is the path prefix for the browser-facing hosted -// OAuth routes (GET /oauth/start/{provider}, GET /oauth/callback/ +// OAuth routes (GET /oauth/start/{provider}, GET/POST /oauth/callback/ // {provider}). These are unauthenticated by design — the user is mid // sign-in and has no JWT yet — so they are exempt as a prefix rather // than per-exact-path (the {provider} segment varies). diff --git a/internal/observability/clients.go b/internal/observability/clients.go index d280629a..3a7016e8 100644 --- a/internal/observability/clients.go +++ b/internal/observability/clients.go @@ -83,12 +83,12 @@ func WrapOAuthExchanger(provider string, e oauth.Exchanger) oauth.Exchanger { return &base } -func (t *TracedExchanger) Exchange(ctx context.Context, code, redirectURI string) (*oauth.Identity, error) { +func (t *TracedExchanger) Exchange(ctx context.Context, params oauth.ExchangeParams) (*oauth.Identity, error) { ctx, end := StartClient( ctx, "oauth.Exchange", attribute.String("oauth.provider", t.provider), ) - id, err := t.inner.Exchange(ctx, code, redirectURI) + id, err := t.inner.Exchange(ctx, params) end(err) return id, err } diff --git a/internal/observability/clients_test.go b/internal/observability/clients_test.go index 81cd6382..77a35646 100644 --- a/internal/observability/clients_test.go +++ b/internal/observability/clients_test.go @@ -67,7 +67,7 @@ type fakeExchanger struct { id *oauth.Identity } -func (f *fakeExchanger) Exchange(context.Context, string, string) (*oauth.Identity, error) { +func (f *fakeExchanger) Exchange(_ context.Context, _ oauth.ExchangeParams) (*oauth.Identity, error) { return f.id, nil } @@ -94,7 +94,7 @@ func TestWrapOAuthExchanger(t *testing.T) { if _, ok := wrapped.(oauth.Authorizer); ok { t.Errorf("plain exchanger wrapper should not satisfy Authorizer") } - id, err := wrapped.Exchange(context.Background(), "code", "https://r") + id, err := wrapped.Exchange(context.Background(), oauth.ExchangeParams{Code: "code", RedirectURI: "https://r"}) if err != nil || id.Email != "x@y" { t.Errorf("Exchange: %v %v", id, err) } diff --git a/internal/service/auth_link_identity.go b/internal/service/auth_link_identity.go index 7783b6b7..be4ec4e2 100644 --- a/internal/service/auth_link_identity.go +++ b/internal/service/auth_link_identity.go @@ -32,7 +32,14 @@ func (s *AuthService) LinkIdentity( } provider = strings.ToLower(strings.TrimSpace(provider)) - identity, err := s.verifyOAuthExchange(ctx, code, provider, redirectURI, codeVerifier, state, stateToken) + identity, err := s.verifyOAuthExchange(ctx, OAuthLoginParams{ + Code: code, + Provider: provider, + RedirectURI: redirectURI, + CodeVerifier: codeVerifier, + State: state, + StateToken: stateToken, + }) if err != nil { if errors.Is(err, errOAuthExchangeFailed) { s.audit.Log( diff --git a/internal/service/auth_login.go b/internal/service/auth_login.go index c02439b1..ddb319d7 100644 --- a/internal/service/auth_login.go +++ b/internal/service/auth_login.go @@ -522,6 +522,18 @@ func (s *AuthService) BeginOAuthLogin( }, nil } +type OAuthLoginParams struct { + Code string + Provider string + RedirectURI string + CodeVerifier string + State string + StateToken string + AppleUserPayload string + IPAddr string + UserAgent string +} + // OAuthLogin performs the full OAuth code-exchange flow: it looks up // the registered Exchanger for the provider, swaps the code for a // verified Identity, then upserts the local user and issues tokens. @@ -531,10 +543,10 @@ func (s *AuthService) BeginOAuthLogin( // refresh tokens are discarded — they are not persisted. func (s *AuthService) OAuthLogin( ctx context.Context, - code, provider, redirectURI, codeVerifier, state, stateToken, ipAddr, userAgent string, + params OAuthLoginParams, ) (*LoginResult, error) { - provider = strings.ToLower(strings.TrimSpace(provider)) - identity, err := s.verifyOAuthExchange(ctx, code, provider, redirectURI, codeVerifier, state, stateToken) + provider := strings.ToLower(strings.TrimSpace(params.Provider)) + identity, err := s.verifyOAuthExchange(ctx, params) if err != nil { if errors.Is(err, errOAuthExchangeFailed) { s.logger.Info( @@ -543,7 +555,7 @@ func (s *AuthService) OAuthLogin( ) s.audit.Log( ctx, audit.EventOAuthLogin, - audit.WithIP(ipAddr), audit.WithUserAgent(userAgent), + audit.WithIP(params.IPAddr), audit.WithUserAgent(params.UserAgent), audit.WithSuccess(false), audit.WithDetails(map[string]any{ "provider": provider, @@ -565,7 +577,7 @@ func (s *AuthService) OAuthLogin( return nil, err } - if err := s.checkAccountStatus(ctx, user, ipAddr, userAgent); err != nil { + if err := s.checkAccountStatus(ctx, user, params.IPAddr, params.UserAgent); err != nil { return nil, err } @@ -590,14 +602,14 @@ func (s *AuthService) OAuthLogin( zap.String("user_id", user.ID), ) - accessToken, refreshToken, err := s.issueTokens(ctx, user, ipAddr, userAgent) + accessToken, refreshToken, err := s.issueTokens(ctx, user, params.IPAddr, params.UserAgent) if err != nil { return nil, err } s.audit.Log( ctx, audit.EventOAuthLogin, - audit.WithActor(user.ID), audit.WithIP(ipAddr), audit.WithUserAgent(userAgent), + audit.WithActor(user.ID), audit.WithIP(params.IPAddr), audit.WithUserAgent(params.UserAgent), audit.WithSuccess(true), audit.WithDetails(map[string]any{ "provider": provider, @@ -648,16 +660,17 @@ var errOAuthExchangeFailed = errors.New("oauth code exchange failed") // exchange failure it returns an error wrapping errOAuthExchangeFailed. func (s *AuthService) verifyOAuthExchange( ctx context.Context, - code, provider, redirectURI, codeVerifier, state, stateToken string, + params OAuthLoginParams, ) (*oauth.Identity, error) { if s.oauthRegistry == nil || s.oauthRegistry.Len() == 0 { return nil, ErrOAuthDisabled } - redirectURI = strings.TrimSpace(redirectURI) + redirectURI := strings.TrimSpace(params.RedirectURI) + provider := strings.ToLower(strings.TrimSpace(params.Provider)) if provider == "" { return nil, fmt.Errorf("%w: provider is required", ErrInvalidArgument) } - if strings.TrimSpace(code) == "" { + if strings.TrimSpace(params.Code) == "" { return nil, fmt.Errorf("%w: code is required", ErrInvalidArgument) } if redirectURI == "" { @@ -669,14 +682,15 @@ func (s *AuthService) verifyOAuthExchange( return nil, fmt.Errorf("%w: unknown oauth provider %q", ErrInvalidArgument, provider) } - if strings.TrimSpace(stateToken) != "" { + codeVerifier := params.CodeVerifier + if strings.TrimSpace(params.StateToken) != "" { claims, err := oauth.VerifyStateToken( - stateToken, + params.StateToken, s.signer, provider, redirectURI, - state, - codeVerifier, + params.State, + params.CodeVerifier, s.nowFunc().UTC(), ) if err != nil { @@ -690,11 +704,12 @@ func (s *AuthService) verifyOAuthExchange( codeVerifier = claims.CodeVerifier } - if strings.TrimSpace(codeVerifier) != "" { - ctx = oauth.WithCodeVerifier(ctx, codeVerifier) - } - - identity, err := exchanger.Exchange(ctx, code, redirectURI) + identity, err := exchanger.Exchange(ctx, oauth.ExchangeParams{ + Code: params.Code, + RedirectURI: redirectURI, + CodeVerifier: codeVerifier, + AppleUserPayload: params.AppleUserPayload, + }) if err != nil { return nil, fmt.Errorf("%w: %w", errOAuthExchangeFailed, err) } diff --git a/internal/service/auth_oauth_hosted.go b/internal/service/auth_oauth_hosted.go index fbcdba66..5f66faeb 100644 --- a/internal/service/auth_oauth_hosted.go +++ b/internal/service/auth_oauth_hosted.go @@ -2,6 +2,7 @@ package service import ( "context" + "crypto/subtle" "errors" "fmt" "strings" @@ -38,7 +39,7 @@ type HostedOAuthBeginResult struct { // BeginOAuthLogin uses — there is no forked authorization path. func (s *AuthService) BeginHostedOAuth( ctx context.Context, - provider, redirectURI, returnTo string, + provider, redirectURI, returnTo, csrfToken string, ) (*HostedOAuthBeginResult, error) { if s.oauthRegistry == nil || s.oauthRegistry.Len() == 0 { return nil, ErrOAuthDisabled @@ -53,6 +54,9 @@ func (s *AuthService) BeginHostedOAuth( if strings.TrimSpace(returnTo) == "" { return nil, fmt.Errorf("%w: return_to is required", ErrInvalidArgument) } + if strings.TrimSpace(csrfToken) == "" { + return nil, fmt.Errorf("%w: csrf_token is required", ErrInvalidArgument) + } exchanger, ok := s.oauthRegistry.Get(provider) if !ok { @@ -79,6 +83,7 @@ func (s *AuthService) BeginHostedOAuth( returnTo, state, codeVerifier, + csrfToken, oauthStateTokenExpiry, s.nowFunc().UTC(), ) @@ -107,8 +112,9 @@ func (s *AuthService) BeginHostedOAuth( // validated return_to plus the freshly-minted one-time code the callback // appends as ?code=. type HostedOAuthCallbackResult struct { - ReturnTo string - Code string + ReturnTo string + Code string + CSRFToken string } // CompleteHostedOAuth runs the hosted callback: it verifies the signed @@ -122,13 +128,22 @@ type HostedOAuthCallbackResult struct { // only to cross-check the token's provider claim. func (s *AuthService) CompleteHostedOAuth( ctx context.Context, - providerFromPath, code, stateToken, ipAddr, userAgent string, + providerFromPath, code, stateToken, appleUserPayload, ipAddr, userAgent string, + csrfTokens []string, ) (*HostedOAuthCallbackResult, error) { claims, err := oauth.VerifyHostedStateToken(stateToken, s.signer, s.nowFunc().UTC()) if err != nil { s.logger.Info("hosted_oauth_state_validation_failed", zap.Error(err)) return nil, fmt.Errorf("%w: invalid oauth state", ErrUnauthenticated) } + matchedCSRFToken := 0 + for _, token := range csrfTokens { + matchedCSRFToken |= subtle.ConstantTimeCompare([]byte(claims.CSRFToken), []byte(token)) + } + if matchedCSRFToken != 1 { + s.logger.Info("hosted_oauth_csrf_mismatch") + return nil, fmt.Errorf("%w: csrf mismatch", ErrUnauthenticated) + } if want := strings.ToLower(strings.TrimSpace(providerFromPath)); want != "" && claims.Provider != want { s.logger.Info("hosted_oauth_provider_mismatch", zap.String("path_provider", want), zap.String("token_provider", claims.Provider)) @@ -140,17 +155,17 @@ func (s *AuthService) CompleteHostedOAuth( // verifier so PKCE completes. OAuthLogin upserts the user and mints // the identity token pair internally; we discard those tokens and // hand back a one-time code instead — the SPA re-mints via redeem. - result, err := s.OAuthLogin( - ctx, - code, - claims.Provider, - claims.RedirectURI, - claims.CodeVerifier, - "", // state already verified against the hosted token - "", // no headless state token in the hosted flow - ipAddr, - userAgent, - ) + result, err := s.OAuthLogin(ctx, OAuthLoginParams{ + Code: code, + Provider: claims.Provider, + RedirectURI: claims.RedirectURI, + CodeVerifier: claims.CodeVerifier, + State: "", // state already verified against the hosted token + StateToken: "", // no headless state token in the hosted flow + AppleUserPayload: appleUserPayload, + IPAddr: ipAddr, + UserAgent: userAgent, + }) if err != nil { return nil, err } @@ -160,7 +175,7 @@ func (s *AuthService) CompleteHostedOAuth( return nil, err } - return &HostedOAuthCallbackResult{ReturnTo: claims.ReturnTo, Code: otc}, nil + return &HostedOAuthCallbackResult{ReturnTo: claims.ReturnTo, Code: otc, CSRFToken: claims.CSRFToken}, nil } // mintOAuthOneTimeCode generates an opaque code, stores its hash bound diff --git a/internal/service/auth_oauth_hosted_test.go b/internal/service/auth_oauth_hosted_test.go index 82e31e06..53e65438 100644 --- a/internal/service/auth_oauth_hosted_test.go +++ b/internal/service/auth_oauth_hosted_test.go @@ -28,14 +28,14 @@ func TestHostedOAuth_BeginCompleteRedeem(t *testing.T) { ctx := context.Background() begin, err := svc.BeginHostedOAuth(ctx, "google", - "https://identity.test/oauth/callback/google", "https://app.test/finish") + "https://identity.test/oauth/callback/google", "https://app.test/finish", "csrf-123") require.NoError(t, err) require.NotEmpty(t, begin.AuthorizationURL) stateToken := stateTokenFromAuthURL(t, begin.AuthorizationURL) cb, err := svc.CompleteHostedOAuth(ctx, "google", fakeOAuthCode("hosted@example.com", "Hosted", "", "google"), - stateToken, "1.2.3.4", "test-agent") + stateToken, "", "1.2.3.4", "test-agent", []string{"csrf-123"}) require.NoError(t, err) assert.Equal(t, "https://app.test/finish", cb.ReturnTo) require.NotEmpty(t, cb.Code) @@ -58,23 +58,23 @@ func TestHostedOAuth_Begin_InputErrors(t *testing.T) { svc := newTestAuthService(t, repo) ctx := context.Background() - _, err := svc.BeginHostedOAuth(ctx, "", "https://identity.test/cb", "https://app.test/") + _, err := svc.BeginHostedOAuth(ctx, "", "https://identity.test/cb", "https://app.test/", "csrf-123") assert.True(t, errors.Is(err, ErrInvalidArgument)) - _, err = svc.BeginHostedOAuth(ctx, "google", "", "https://app.test/") + _, err = svc.BeginHostedOAuth(ctx, "google", "", "https://app.test/", "csrf-123") assert.True(t, errors.Is(err, ErrInvalidArgument)) - _, err = svc.BeginHostedOAuth(ctx, "google", "https://identity.test/cb", "") + _, err = svc.BeginHostedOAuth(ctx, "google", "https://identity.test/cb", "", "csrf-123") assert.True(t, errors.Is(err, ErrInvalidArgument)) - _, err = svc.BeginHostedOAuth(ctx, "unknown", "https://identity.test/cb", "https://app.test/") + _, err = svc.BeginHostedOAuth(ctx, "unknown", "https://identity.test/cb", "https://app.test/", "csrf-123") assert.True(t, errors.Is(err, ErrInvalidArgument)) } func TestHostedOAuth_Begin_DisabledNoRegistry(t *testing.T) { svc := newTestAuthServiceNoOAuth(t, newFakeRepo()) _, err := svc.BeginHostedOAuth(context.Background(), "google", - "https://identity.test/cb", "https://app.test/") + "https://identity.test/cb", "https://app.test/", "csrf-123") assert.True(t, errors.Is(err, ErrOAuthDisabled)) } @@ -84,7 +84,7 @@ func TestHostedOAuth_Complete_TamperedStateRejected(t *testing.T) { ctx := context.Background() begin, err := svc.BeginHostedOAuth(ctx, "google", - "https://identity.test/oauth/callback/google", "https://app.test/finish") + "https://identity.test/oauth/callback/google", "https://app.test/finish", "csrf-123") require.NoError(t, err) stateToken := stateTokenFromAuthURL(t, begin.AuthorizationURL) @@ -92,7 +92,24 @@ func TestHostedOAuth_Complete_TamperedStateRejected(t *testing.T) { tampered := stateToken[:len(stateToken)-2] + "AA" _, err = svc.CompleteHostedOAuth(ctx, "google", fakeOAuthCode("hosted@example.com", "Hosted", "", "google"), - tampered, "", "") + tampered, "", "", "", []string{"csrf-123"}) + require.Error(t, err) + assert.True(t, errors.Is(err, ErrUnauthenticated)) +} + +func TestHostedOAuth_Complete_CSRFMismatchRejected(t *testing.T) { + repo := newFakeRepo() + svc := newTestAuthService(t, repo) + ctx := context.Background() + + begin, err := svc.BeginHostedOAuth(ctx, "google", + "https://identity.test/oauth/callback/google", "https://app.test/finish", "csrf-123") + require.NoError(t, err) + stateToken := stateTokenFromAuthURL(t, begin.AuthorizationURL) + + _, err = svc.CompleteHostedOAuth(ctx, "google", + fakeOAuthCode("hosted@example.com", "Hosted", "", "google"), + stateToken, "", "", "", []string{"wrong-csrf-token"}) require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) } @@ -103,14 +120,14 @@ func TestHostedOAuth_Complete_ProviderMismatchRejected(t *testing.T) { ctx := context.Background() begin, err := svc.BeginHostedOAuth(ctx, "google", - "https://identity.test/oauth/callback/google", "https://app.test/finish") + "https://identity.test/oauth/callback/google", "https://app.test/finish", "csrf-123") require.NoError(t, err) stateToken := stateTokenFromAuthURL(t, begin.AuthorizationURL) // The token was minted for google but the callback path says github. _, err = svc.CompleteHostedOAuth(ctx, "github", fakeOAuthCode("hosted@example.com", "Hosted", "", "google"), - stateToken, "", "") + stateToken, "", "", "", []string{"csrf-123"}) require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) } diff --git a/internal/service/auth_oauth_linkage_test.go b/internal/service/auth_oauth_linkage_test.go index a41bc055..aa92e116 100644 --- a/internal/service/auth_oauth_linkage_test.go +++ b/internal/service/auth_oauth_linkage_test.go @@ -29,7 +29,7 @@ func TestOAuthLogin_ReturningUserViaProviderID(t *testing.T) { // fakeOAuthExchanger encodes ProviderUserID as "sub-", so we // drive a login whose sub deterministically matches the seeded link. code := fakeOAuthCode("stable-123", "Stable", "", "google") - res, err := svc.OAuthLogin(ctx, code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, seed.ID, res.User.ID) @@ -51,7 +51,7 @@ func TestOAuthLogin_FirstTimeLinkToExistingUser(t *testing.T) { seed := seedUser(repo, "alice@example.com", "", "active") code := fakeOAuthCode("alice@example.com", "Alice", "", "google") - res, err := svc.OAuthLogin(ctx, code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, seed.ID, res.User.ID) @@ -71,7 +71,7 @@ func TestOAuthLogin_NewUserGetsIdentityLink(t *testing.T) { ctx := context.Background() code := fakeOAuthCode("brand-new@example.com", "Newbie", "", "microsoft") - res, err := svc.OAuthLogin(ctx, code, "microsoft", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: code, Provider: "microsoft", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) require.NotNil(t, res) @@ -88,9 +88,8 @@ func TestOAuthLogin_LinkFailureDoesNotFailLogin(t *testing.T) { ctx := context.Background() repo.failCreateOAuthIdentity = true - res, err := svc.OAuthLogin(ctx, - fakeOAuthCode("link-failure@example.com", "Link Failure", "", "google"), - "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: fakeOAuthCode("link-failure@example.com", "Link Failure", "", "google"), Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, "link-failure@example.com", res.User.Email) @@ -117,7 +116,7 @@ func TestOAuthLogin_ProviderEmailChangedStaysLinked(t *testing.T) { // First login: creates user@old.com and links sub-stableid. first := fakeOAuthCode("stableid", "User", "", "google") - res1, err := svc.OAuthLogin(ctx, first, "google", "https://app/cb", "", "", "", "", "") + res1, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: first, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) originalID := res1.User.ID assert.Equal(t, "stableid", res1.User.Email) @@ -139,7 +138,7 @@ func TestOAuthLogin_ProviderEmailChangedStaysLinked(t *testing.T) { // Drive a second login with the SAME provider+sub: the fakeExchanger // happens to return the same sub for the same code, so this proves // the lookup returns the original user (no duplicate created). - res2, err := svc.OAuthLogin(ctx, first, "google", "https://app/cb", "", "", "", "", "") + res2, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: first, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, originalID, res2.User.ID) @@ -174,7 +173,7 @@ func TestOAuthLogin_ProviderEmailChangeDoesNotMutateLocal(t *testing.T) { })) code := fakeOAuthCode("newaddr@example.com", "User", "", "google") - res, err := svc.OAuthLogin(ctx, code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, seed.ID, res.User.ID, "must resolve to original user via (provider, sub)") assert.Equal(t, "old@example.com", res.User.Email, "provider-side email change must NOT mutate local email") @@ -196,16 +195,14 @@ func TestOAuthLogin_CrossProviderLinking(t *testing.T) { ctx := context.Background() // Google login first. - res1, err := svc.OAuthLogin(ctx, - fakeOAuthCode("multi@example.com", "Multi", "", "google"), - "google", "https://app/cb", "", "", "", "", "") + res1, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: fakeOAuthCode("multi@example.com", "Multi", "", "google"), Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.NoError(t, err) uid := res1.User.ID // Microsoft login with the same email — should land on the SAME user. - res2, err := svc.OAuthLogin(ctx, - fakeOAuthCode("multi@example.com", "Multi", "", "microsoft"), - "microsoft", "https://app/cb", "", "", "", "", "") + res2, err := svc.OAuthLogin(ctx, OAuthLoginParams{Code: fakeOAuthCode("multi@example.com", "Multi", "", "microsoft"), Provider: "microsoft", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.NoError(t, err) assert.Equal(t, uid, res2.User.ID) diff --git a/internal/service/auth_oauth_test.go b/internal/service/auth_oauth_test.go index ea40a648..c9a73e15 100644 --- a/internal/service/auth_oauth_test.go +++ b/internal/service/auth_oauth_test.go @@ -17,7 +17,7 @@ type oauthExchangeOnly struct { calls int } -func (f *oauthExchangeOnly) Exchange(_ context.Context, _, _ string) (*oauth.Identity, error) { +func (f *oauthExchangeOnly) Exchange(_ context.Context, _ oauth.ExchangeParams) (*oauth.Identity, error) { f.calls++ if f.err != nil { return nil, f.err @@ -29,9 +29,8 @@ func TestOAuthLogin_Disabled_NoRegistry(t *testing.T) { repo := newFakeRepo() svc := newTestAuthServiceNoOAuth(t, repo) - _, err := svc.OAuthLogin(context.Background(), - fakeOAuthCode("u@example.com", "U", "", "google"), - "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: fakeOAuthCode("u@example.com", "U", "", "google"), Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.Error(t, err) assert.True(t, errors.Is(err, ErrOAuthDisabled)) } @@ -40,9 +39,8 @@ func TestOAuthLogin_UnknownProvider(t *testing.T) { repo := newFakeRepo() svc := newTestAuthService(t, repo) - _, err := svc.OAuthLogin(context.Background(), - fakeOAuthCode("u@example.com", "U", "", "yahoo"), - "yahoo", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: fakeOAuthCode("u@example.com", "U", "", "yahoo"), Provider: "yahoo", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.Error(t, err) assert.True(t, errors.Is(err, ErrInvalidArgument)) } @@ -95,7 +93,7 @@ func TestOAuthLogin_NewUserCreatedAndAudited(t *testing.T) { svc := newTestAuthService(t, repo) code := fakeOAuthCode("new-oauth@example.com", "Newcomer", "https://avatar/", "google") - res, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "10.0.0.1", "TestAgent") + res, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "10.0.0.1", UserAgent: "TestAgent"}) require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, "new-oauth@example.com", res.User.Email) @@ -117,7 +115,7 @@ func TestOAuthLogin_ExistingUserLooksUpByEmail(t *testing.T) { seed := seedUser(repo, "alice@example.com", "", "active") code := fakeOAuthCode("alice@example.com", "Alice Updated", "https://av/", "google") - res, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, seed.ID, res.User.ID) assert.Equal(t, "Alice Updated", res.User.Name) @@ -128,8 +126,8 @@ func TestOAuthLogin_ExchangerErrorPropagatesAsUnauthenticated(t *testing.T) { svc := newTestAuthService(t, repo) // "err|..." form makes the fake exchanger return ErrCodeExchangeFailed. - _, err := svc.OAuthLogin(context.Background(), - "err|something-bad", "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: "err|something-bad", Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) } @@ -138,8 +136,8 @@ func TestOAuthLogin_UnverifiedEmailRejected(t *testing.T) { repo := newFakeRepo() svc := newTestAuthService(t, repo) - _, err := svc.OAuthLogin(context.Background(), - "unverified|u@example.com", "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: "unverified|u@example.com", Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) } @@ -155,16 +153,8 @@ func TestOAuthLogin_StateMismatchRejected(t *testing.T) { require.NoError(t, err) _, err = svc.OAuthLogin( - context.Background(), - fakeOAuthCode("state-mismatch@example.com", "Mismatch", "", "google"), - "google", - "https://app/cb", - "", - begin.State+"-wrong", - begin.StateToken, - "", - "", - ) + context.Background(), OAuthLoginParams{Code: fakeOAuthCode("state-mismatch@example.com", "Mismatch", "", "google"), Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: begin.State + "-wrong", StateToken: begin.StateToken, AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) assert.Zero(t, exchanger.calls.Load()) @@ -180,16 +170,8 @@ func TestOAuthLogin_StateTokenAllowsCallbackWithoutExplicitVerifier(t *testing.T begin, err := svc.BeginOAuthLogin(context.Background(), "google", "https://app/cb") require.NoError(t, err) res, err := svc.OAuthLogin( - context.Background(), - fakeOAuthCode("state-token@example.com", "State Token", "", "google"), - "google", - "https://app/cb", - "", - begin.State, - begin.StateToken, - "", - "", - ) + context.Background(), OAuthLoginParams{Code: fakeOAuthCode("state-token@example.com", "State Token", "", "google"), Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: begin.State, StateToken: begin.StateToken, AppleUserPayload: "", IPAddr: "", UserAgent: ""}) + require.NoError(t, err) require.NotNil(t, res) assert.Equal(t, "state-token@example.com", res.User.Email) @@ -205,7 +187,7 @@ func TestOAuthLogin_AuditEventsRecorded(t *testing.T) { svc := newTestAuthService(t, repo) code := fakeOAuthCode("audit@example.com", "Au", "", "github") - res, err := svc.OAuthLogin(context.Background(), code, "github", "https://app/cb", "", "", "", "1.2.3.4", "Mozilla/5.0") + res, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "github", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "1.2.3.4", UserAgent: "Mozilla/5.0"}) require.NoError(t, err) assert.Equal(t, "audit@example.com", res.User.Email) } @@ -214,7 +196,7 @@ func TestOAuthLogin_EmptyProviderInvalid(t *testing.T) { repo := newFakeRepo() svc := newTestAuthService(t, repo) - _, err := svc.OAuthLogin(context.Background(), "code", "", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: "code", Provider: "", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) assert.True(t, errors.Is(err, ErrInvalidArgument)) } @@ -231,7 +213,7 @@ func TestOAuthLogin_ProviderReturnsNoEmailRejected(t *testing.T) { registry.Register("google", exchanger) svc := newTestAuthServiceWithRegistry(t, repo, registry) - _, err := svc.OAuthLogin(context.Background(), "code", "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: "code", Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) assert.True(t, errors.Is(err, ErrUnauthenticated)) assert.Equal(t, 1, exchanger.calls) @@ -253,7 +235,7 @@ func TestOAuthLogin_ExchangerInvoked(t *testing.T) { svc := newTestAuthServiceWithRegistry(t, repo, r) code := fakeOAuthCode("u@example.com", "U", "", "google") - if _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", ""); err != nil { + if _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}); err != nil { t.Fatalf("OAuthLogin: %v", err) } if exch.calls.Load() != 1 { diff --git a/internal/service/auth_test.go b/internal/service/auth_test.go index 541fe2f8..b16e66d2 100644 --- a/internal/service/auth_test.go +++ b/internal/service/auth_test.go @@ -432,7 +432,7 @@ func TestOAuthLogin_CreatesNewUser(t *testing.T) { svc := newTestAuthService(t, repo) code := fakeOAuthCode("oauth@example.com", "OAuth User", "https://img.example.com/pic.jpg", "google") - result, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + result, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, "oauth@example.com", result.User.Email) assert.Equal(t, "OAuth User", result.User.Name) @@ -445,7 +445,7 @@ func TestOAuthLogin_ExistingUserUpdatesProfile(t *testing.T) { seedUser(repo, "existing@example.com", "", "active") code := fakeOAuthCode("existing@example.com", "New Name", "https://pic.url", "google") - result, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + result, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, "existing@example.com", result.User.Email) } diff --git a/internal/service/error_paths_test.go b/internal/service/error_paths_test.go index 61b84696..9e3df3b9 100644 --- a/internal/service/error_paths_test.go +++ b/internal/service/error_paths_test.go @@ -82,7 +82,7 @@ func TestOAuthLogin_FindUserErrors(t *testing.T) { svc := newTestAuthServiceErr(t, r) code := fakeOAuthCode("x@example.com", "X", "", "google") - _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) } @@ -92,7 +92,7 @@ func TestOAuthLogin_CreateUserErrors(t *testing.T) { svc := newTestAuthServiceErr(t, r) code := fakeOAuthCode("x@example.com", "X", "", "google") - _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) } @@ -102,7 +102,7 @@ func TestOAuthLogin_IssueTokensFails(t *testing.T) { svc := newTestAuthServiceErr(t, r) code := fakeOAuthCode("x@example.com", "X", "", "google") - _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) } @@ -574,7 +574,7 @@ func TestOAuthLogin_ExistingUserUpdateWarns(t *testing.T) { // Should still succeed because the update failure is logged but not propagated. code := fakeOAuthCode("ouw@example.com", "Different Name", "https://avatar.png", "google") - res, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.NotNil(t, res) } diff --git a/internal/service/extended_test.go b/internal/service/extended_test.go index 79c21c74..30d8b201 100644 --- a/internal/service/extended_test.go +++ b/internal/service/extended_test.go @@ -115,7 +115,7 @@ func TestOAuthLogin_EmptyCodeFails(t *testing.T) { repo := newFakeRepo() svc := newTestAuthService(t, repo) - _, err := svc.OAuthLogin(context.Background(), "", "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: "", Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) assert.True(t, errors.Is(err, ErrInvalidArgument)) } @@ -125,7 +125,7 @@ func TestOAuthLogin_DefaultsDisplayNameFromEmail(t *testing.T) { svc := newTestAuthService(t, repo) code := fakeOAuthCode("carol@example.com", "", "", "google") - result, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + result, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) assert.Equal(t, "carol", result.User.Name) } @@ -136,7 +136,7 @@ func TestOAuthLogin_DeactivatedAccountFails(t *testing.T) { seedUser(repo, "deac-oauth@example.com", "", "deactivated") code := fakeOAuthCode("deac-oauth@example.com", "X", "", "google") - _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.Error(t, err) assert.True(t, errors.Is(err, ErrAccountNotActive)) } @@ -148,7 +148,7 @@ func TestOAuthLogin_ExistingUserNoNameChangeNoUpdate(t *testing.T) { // Same display name and no avatar -- no patch should happen, but call must succeed. code := fakeOAuthCode("noupd@example.com", "noupd", "", "google") - _, err := svc.OAuthLogin(context.Background(), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(context.Background(), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) } diff --git a/internal/service/login_policy_enforce_test.go b/internal/service/login_policy_enforce_test.go index 23794ee4..7e455f4a 100644 --- a/internal/service/login_policy_enforce_test.go +++ b/internal/service/login_policy_enforce_test.go @@ -478,7 +478,7 @@ func TestOAuthLogin_DeniedByPolicy(t *testing.T) { svc.WithLoginGovernance(claimedPasswordOnlyGovernance()) code := fakeOAuthCode("alice@acme.com", "Alice", "", "google") - _, err := svc.OAuthLogin(withProject("proj-1"), code, "google", "https://app/cb", "", "", "", "", "") + _, err := svc.OAuthLogin(withProject("proj-1"), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.ErrorIs(t, err, ErrPermissionDenied, "oauth must be denied for a password-only tenant") } @@ -490,7 +490,7 @@ func TestOAuthLogin_AllowedByPolicy(t *testing.T) { svc.WithLoginGovernance(withAllowedMethods(LoginMethodOAuth)) code := fakeOAuthCode("alice@acme.com", "Alice", "", "google") - res, err := svc.OAuthLogin(withProject("proj-1"), code, "google", "https://app/cb", "", "", "", "", "") + res, err := svc.OAuthLogin(withProject("proj-1"), OAuthLoginParams{Code: code, Provider: "google", RedirectURI: "https://app/cb", CodeVerifier: "", State: "", StateToken: "", AppleUserPayload: "", IPAddr: "", UserAgent: ""}) require.NoError(t, err) require.NotEmpty(t, res.AccessToken) } @@ -503,11 +503,11 @@ func TestRedeemOAuthCode_DeniedByPolicy(t *testing.T) { ctx := withProject("proj-1") begin, err := svc.BeginHostedOAuth(ctx, "google", - "https://identity.test/oauth/callback/google", "https://app.test/finish") + "https://identity.test/oauth/callback/google", "https://app.test/finish", "csrf-123") require.NoError(t, err) cb, err := svc.CompleteHostedOAuth(ctx, "google", fakeOAuthCode("alice@acme.com", "Alice", "", "google"), - stateTokenFromAuthURL(t, begin.AuthorizationURL), "1.2.3.4", "test-agent") + stateTokenFromAuthURL(t, begin.AuthorizationURL), "", "1.2.3.4", "test-agent", []string{"csrf-123"}) require.NoError(t, err) // Policy is consulted at redeem, the point tokens would be issued. diff --git a/internal/service/return_allowlist.go b/internal/service/return_allowlist.go index 41ef605d..cfa97c4b 100644 --- a/internal/service/return_allowlist.go +++ b/internal/service/return_allowlist.go @@ -1,14 +1,18 @@ package service -import "strings" +import ( + "net/url" + "path" + "strings" +) // ReturnAllowlist is the fail-closed validator for an app return_to URL. // It is built from GATEWAY_OAUTH_ALLOWED_RETURN_URLS — a comma-separated -// list of exact origins or URL prefixes — and shared by the hosted OAuth +// list of exact origins or path prefixes — and shared by the hosted OAuth // flow (where the HTTP handler checks return_to at /oauth/start) and the // passwordless magic-link flow (where RequestMagicLink checks the -// requested return_to). A return_to is allowed when it equals an entry or -// begins with one; everything else is rejected. +// requested return_to). A return_to must have the configured origin and, +// for path entries, match the configured path or one of its descendants. // // An empty allowlist disables both flows that depend on it: Enabled() // reports false and Allows() rejects everything. @@ -34,21 +38,57 @@ func (a ReturnAllowlist) Enabled() bool { return len(a.entries) > 0 } // Entries returns the configured allowlist entries (for startup logging). func (a ReturnAllowlist) Entries() []string { return a.entries } -// Allows reports whether returnTo is permitted. A returnTo matches when -// it equals an allowlist entry or begins with one (prefix match), so a -// deployer can allow an entire app origin or pin a specific callback -// path. The match is exact-byte; no normalization is applied because a -// normalized-but-mismatched URL is exactly the open-redirect case the -// allowlist exists to close. +// Allows reports whether returnTo is permitted by an exact origin or a +// path-bound prefix. Allowlist entries may not contain a query or fragment. func (a ReturnAllowlist) Allows(returnTo string) bool { - returnTo = strings.TrimSpace(returnTo) - if returnTo == "" { + returnURL, ok := parseReturnURL(returnTo) + if !ok { return false } for _, e := range a.entries { - if returnTo == e || strings.HasPrefix(returnTo, e) { + entryURL, ok := parseReturnURL(e) + if !ok || entryURL.RawQuery != "" || entryURL.Fragment != "" { + continue + } + if sameReturnOrigin(entryURL, returnURL) && returnPathAllowed(entryURL, returnURL) { return true } } return false } + +func parseReturnURL(raw string) (*url.URL, bool) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil || u.Scheme == "" || u.Host == "" || u.User != nil || u.Hostname() == "" { + return nil, false + } + if !strings.EqualFold(u.Scheme, "https") && !strings.EqualFold(u.Scheme, "http") { + return nil, false + } + return u, true +} + +func sameReturnOrigin(a, b *url.URL) bool { + return strings.EqualFold(a.Scheme, b.Scheme) && + strings.EqualFold(a.Hostname(), b.Hostname()) && + returnPort(a) == returnPort(b) +} + +func returnPort(u *url.URL) string { + if port := u.Port(); port != "" { + return port + } + if strings.EqualFold(u.Scheme, "https") { + return "443" + } + return "80" +} + +func returnPathAllowed(entry, returnTo *url.URL) bool { + if entry.Path == "" || entry.Path == "/" { + return true + } + prefix := strings.TrimSuffix(path.Clean(entry.Path), "/") + returnPath := path.Clean(returnTo.Path) + return returnPath == prefix || strings.HasPrefix(returnPath, prefix+"/") +} diff --git a/internal/service/return_allowlist_test.go b/internal/service/return_allowlist_test.go index 24e8ff0d..aeee36d0 100644 --- a/internal/service/return_allowlist_test.go +++ b/internal/service/return_allowlist_test.go @@ -32,18 +32,22 @@ func TestParseReturnAllowlist(t *testing.T) { func TestReturnAllowlist_Allows(t *testing.T) { t.Parallel() - a := ParseReturnAllowlist("https://app.example.com/,https://other.example.org/auth") + a := ParseReturnAllowlist("https://app.example.com,https://other.example.org/auth") tests := []struct { returnTo string want bool }{ {"https://app.example.com/", true}, - {"https://app.example.com/auth/finish?next=/home", true}, // prefix match + {"https://app.example.com/auth/finish?next=/home", true}, {"https://other.example.org/auth", true}, {"https://other.example.org/auth/callback", true}, + {"https://other.example.org/authentication", false}, + {"https://other.example.org/auth/../unrelated", false}, {"https://evil.example.net/", false}, - {"https://app.example.com.evil.net/", false}, // not a prefix of an entry - {"http://app.example.com/", false}, // scheme differs + {"https://app.example.com.evil.net/", false}, + {"https://app.example.com.attacker.tld/", false}, + {"https://app.example.com@evil.example.net/", false}, + {"http://app.example.com/", false}, {"", false}, {" ", false}, } @@ -54,6 +58,20 @@ func TestReturnAllowlist_Allows(t *testing.T) { } } +func TestReturnAllowlist_RejectsEntriesWithQueryOrFragment(t *testing.T) { + t.Parallel() + + for _, entry := range []string{ + "https://app.example.com/callback?source=oauth", + "https://app.example.com/callback#fragment", + } { + a := ParseReturnAllowlist(entry) + if a.Allows("https://app.example.com/callback") { + t.Fatalf("entry %q allowed a return_to", entry) + } + } +} + func TestReturnAllowlist_EmptyDeniesAll(t *testing.T) { t.Parallel() a := ParseReturnAllowlist("") diff --git a/internal/service/testutil_test.go b/internal/service/testutil_test.go index d4f7cec5..a769fad0 100644 --- a/internal/service/testutil_test.go +++ b/internal/service/testutil_test.go @@ -1028,9 +1028,9 @@ type fakeOAuthExchanger struct { calls atomic.Int32 } -func (f *fakeOAuthExchanger) Exchange(_ context.Context, code, _ string) (*oauth.Identity, error) { +func (f *fakeOAuthExchanger) Exchange(_ context.Context, params oauth.ExchangeParams) (*oauth.Identity, error) { f.calls.Add(1) - parts := splitCode(code) + parts := splitCode(params.Code) switch parts[0] { case "ok": return &oauth.Identity{ @@ -1078,10 +1078,7 @@ func splitCode(code string) []string { return out } -// defaultTestOAuthRegistry returns a registry pre-populated with a -// fakeOAuthExchanger for "google", "microsoft", and "github" so the -// existing test suite continues to work after the OAuthLogin signature -// change. +// defaultTestOAuthRegistry returns the providers exercised by service tests. func defaultTestOAuthRegistry() *oauth.Registry { r := oauth.NewRegistry() for _, p := range []string{"google", "microsoft", "github"} { diff --git a/pkg/oauth/apple.go b/pkg/oauth/apple.go new file mode 100644 index 00000000..8cf599eb --- /dev/null +++ b/pkg/oauth/apple.go @@ -0,0 +1,320 @@ +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +const ( + appleAuthorizationURL = "https://appleid.apple.com/auth/authorize" + appleTokenURL = "https://appleid.apple.com/auth/token" //nolint:gosec // this is a public URL, not a credential + appleJWKSURL = "https://appleid.apple.com/auth/keys" + appleIssuer = "https://appleid.apple.com" +) + +// AppleConfig configures an Apple Exchanger. ClientID, TeamID, +// KeyID, and PrivateKey are required. +type AppleConfig struct { + ClientID string + TeamID string + KeyID string + PrivateKey string + + TokenURL string + JWKSURL string + Issuer string + + HTTPClient *http.Client + JWKSCacheTTL time.Duration + Now func() time.Time +} + +type appleExchanger struct { + cfg AppleConfig + client *http.Client + jwks *jwksCache + + mu sync.Mutex + parsedKey jwk.Key + cachedToken string + tokenExp time.Time +} + +// NewApple returns an Exchanger that implements Sign-In with Apple. +func NewApple(cfg AppleConfig) Exchanger { + if cfg.HTTPClient == nil { + cfg.HTTPClient = defaultHTTPClient() + } + if cfg.JWKSCacheTTL == 0 { + cfg.JWKSCacheTTL = time.Hour + } + if cfg.Now == nil { + cfg.Now = time.Now + } + if cfg.TokenURL == "" { + cfg.TokenURL = appleTokenURL + } + if cfg.JWKSURL == "" { + cfg.JWKSURL = appleJWKSURL + } + if cfg.Issuer == "" { + cfg.Issuer = appleIssuer + } + return &appleExchanger{ + cfg: cfg, + client: cfg.HTTPClient, + jwks: newJWKSCache(cfg.JWKSURL, cfg.JWKSCacheTTL, cfg.HTTPClient), + } +} + +func (a *appleExchanger) clientSecret() (string, error) { + a.mu.Lock() + defer a.mu.Unlock() + + now := a.cfg.Now() + if a.cachedToken != "" && now.Before(a.tokenExp) { + return a.cachedToken, nil + } + + if a.parsedKey == nil { + pkBytes := []byte(strings.TrimSpace(a.cfg.PrivateKey)) + if !strings.HasPrefix(string(pkBytes), "-----BEGIN") { + // Try base64 decoding if it doesn't look like PEM + dec, err := base64.StdEncoding.DecodeString(string(pkBytes)) + if err == nil { + pkBytes = dec + } + } + + key, err := jwk.ParseKey(pkBytes, jwk.WithPEM(true)) + if err != nil { + return "", fmt.Errorf("parse private key: %w", err) + } + if err := key.Set(jwk.KeyIDKey, a.cfg.KeyID); err != nil { + return "", err + } + a.parsedKey = key + } + + exp := now.Add(5 * time.Minute) + tok, err := jwt.NewBuilder(). + Issuer(a.cfg.TeamID). + IssuedAt(now). + Expiration(exp). + Audience([]string{a.cfg.Issuer}). + Subject(a.cfg.ClientID). + Build() + if err != nil { + return "", fmt.Errorf("build client secret: %w", err) + } + + signed, err := jwt.Sign(tok, jwt.WithKey(jwa.ES256, a.parsedKey)) + if err != nil { + return "", fmt.Errorf("sign client secret: %w", err) + } + + a.cachedToken = string(signed) + a.tokenExp = exp.Add(-30 * time.Second) // 30s buffer + return a.cachedToken, nil +} + +func (a *appleExchanger) AuthorizationURL(ctx context.Context, redirectURI, state, codeChallenge string) (string, error) { + if a.cfg.ClientID == "" { + return "", fmt.Errorf("%w: client credentials not configured", ErrCodeExchangeFailed) + } + + params := url.Values{} + params.Set("client_id", a.cfg.ClientID) + params.Set("redirect_uri", redirectURI) + params.Set("response_type", "code") + params.Set("response_mode", "form_post") + params.Set("scope", "name email") + if err := addPKCEParams(params, state, codeChallenge); err != nil { + return "", err + } + return buildAuthorizationURL(appleAuthorizationURL, params) +} + +func (a *appleExchanger) Exchange(ctx context.Context, params ExchangeParams) (*Identity, error) { + if a.cfg.ClientID == "" || a.cfg.PrivateKey == "" { + return nil, fmt.Errorf("%w: client credentials not configured", ErrCodeExchangeFailed) + } + + secret, err := a.clientSecret() + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrCodeExchangeFailed, err) + } + + form := url.Values{} + form.Set("code", params.Code) + form.Set("client_id", a.cfg.ClientID) + form.Set("client_secret", secret) + form.Set("redirect_uri", params.RedirectURI) + form.Set("grant_type", "authorization_code") + if params.CodeVerifier != "" { + form.Set("code_verifier", params.CodeVerifier) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.cfg.TokenURL, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("%w: build request: %w", ErrCodeExchangeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrCodeExchangeFailed, err) + } + defer func() { _ = resp.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("%w: read body: %w", ErrCodeExchangeFailed, err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: provider HTTP %d", ErrCodeExchangeFailed, resp.StatusCode) + } + + var tr struct { + IDToken string `json:"id_token"` + Error string `json:"error"` + } + if err := json.Unmarshal(body, &tr); err != nil { + return nil, fmt.Errorf("%w: parse response: %w", ErrCodeExchangeFailed, err) + } + if tr.Error != "" { + return nil, fmt.Errorf("%w: %s", ErrCodeExchangeFailed, tr.Error) + } + if tr.IDToken == "" { + return nil, fmt.Errorf("%w: provider returned no id_token", ErrCodeExchangeFailed) + } + + claims, err := a.verifyIDToken(ctx, tr.IDToken) + if err != nil { + return nil, err + } + + email := strings.ToLower(strings.TrimSpace(claims.Email)) + if email == "" { + return nil, fmt.Errorf("%w: missing email", ErrIdentityVerification) + } + + verified := false + switch v := claims.EmailVerified.(type) { + case bool: + verified = v + case string: + verified = v == "true" + } + if !verified { + return nil, fmt.Errorf("%w: email not verified", ErrIdentityVerification) + } + + name := claims.Name + if params.AppleUserPayload != "" { + var appleUser struct { + Name struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + } `json:"name"` + } + if err := json.Unmarshal([]byte(params.AppleUserPayload), &appleUser); err == nil { + first := strings.TrimSpace(appleUser.Name.FirstName) + last := strings.TrimSpace(appleUser.Name.LastName) + if first != "" || last != "" { + name = strings.TrimSpace(first + " " + last) + } + } + } + + return &Identity{ + ProviderUserID: claims.Sub, + Email: email, + EmailVerified: true, + Name: name, + AvatarURL: "", + Provider: "apple", + }, nil +} + +type appleIDClaims struct { + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified interface{} `json:"email_verified"` + Name string `json:"name"` +} + +func (a *appleExchanger) verifyIDToken(ctx context.Context, raw string) (*appleIDClaims, error) { + set, err := a.jwks.Get(ctx) + if err != nil { + return nil, fmt.Errorf("%w: jwks: %w", ErrIdentityVerification, err) + } + + payload, err := verifyJWS(raw, set) + if err != nil && errors.Is(err, errKeyNotFound) { + a.jwks.Invalidate() + set2, fErr := a.jwks.Get(ctx) + if fErr != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) + } + payload, err = verifyJWS(raw, set2) + } + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) + } + + tok, err := jwt.Parse(payload, jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + return nil, fmt.Errorf("%w: parse claims: %w", ErrIdentityVerification, err) + } + if iss := tok.Issuer(); iss != a.cfg.Issuer { + return nil, fmt.Errorf("%w: bad iss: %s", ErrIdentityVerification, iss) + } + auds := tok.Audience() + if !containsString(auds, a.cfg.ClientID) { + return nil, fmt.Errorf("%w: bad aud", ErrIdentityVerification) + } + now := a.cfg.Now() + if exp := tok.Expiration(); !exp.IsZero() && now.After(exp) { + return nil, fmt.Errorf("%w: token expired", ErrIdentityVerification) + } + if iat := tok.IssuedAt(); !iat.IsZero() && iat.After(now.Add(2*time.Minute)) { + return nil, fmt.Errorf("%w: iat in the future", ErrIdentityVerification) + } + + var claims appleIDClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, fmt.Errorf("%w: decode claims: %w", ErrIdentityVerification, err) + } + if claims.Sub == "" { + return nil, fmt.Errorf("%w: missing sub", ErrIdentityVerification) + } + if claims.Email == "" { + return nil, fmt.Errorf("%w: missing email", ErrIdentityVerification) + } + + var verified bool + if ev, ok := claims.EmailVerified.(string); ok && ev == "true" { + verified = true + } else if ev, ok := claims.EmailVerified.(bool); ok && ev { + verified = true + } + if !verified { + return nil, fmt.Errorf("%w: email not verified: %s", ErrEmailNotVerified, claims.Email) + } + + return &claims, nil +} diff --git a/pkg/oauth/apple_test.go b/pkg/oauth/apple_test.go new file mode 100644 index 00000000..4a7eeb43 --- /dev/null +++ b/pkg/oauth/apple_test.go @@ -0,0 +1,498 @@ +package oauth + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "net/http" + "strings" + "testing" + "time" +) + +func generateTestECDSAKey(tb testing.TB) string { + tb.Helper() + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + tb.Fatalf("ecdsa generate: %v", err) + } + + x509Encoded, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + tb.Fatalf("failed to marshal ECDSA key: %v", err) + } + + pemEncoded := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: x509Encoded, + }) + + return string(pemEncoded) +} + +func TestApple_ExchangeSuccess(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + + const clientID = "apple-client-id" + privateKeyPEM := generateTestECDSAKey(t) + + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "apple-sub-1", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "user@example.com", + "email_verified": true, // standard boolean + }) + + fp.tokenHandler = jsonHandler(map[string]any{ + "id_token": idToken, + "access_token": "discardable", + "token_type": "Bearer", + "expires_in": 3600, + }) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team-123", + KeyID: "key-123", + PrivateKey: privateKeyPEM, + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "the-code", RedirectURI: "https://app/cb"}) + if err != nil { + t.Fatalf("Exchange: %v", err) + } + if id.Email != "user@example.com" { + t.Errorf("email = %q", id.Email) + } + if id.ProviderUserID != "apple-sub-1" { + t.Errorf("provider id = %q", id.ProviderUserID) + } + if !id.EmailVerified { + t.Error("email_verified must be true") + } + if id.Provider != "apple" { + t.Errorf("provider = %q", id.Provider) + } +} + +func TestApple_EmailVerifiedString(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + + const clientID = "apple-client-id" + privateKeyPEM := generateTestECDSAKey(t) + + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "apple-sub-1", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "user@example.com", + "email_verified": "true", // Apple string format + }) + + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team-123", + KeyID: "key-123", + PrivateKey: privateKeyPEM, + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "the-code", RedirectURI: "https://app/cb"}) + if err != nil { + t.Fatalf("Exchange: %v", err) + } + if !id.EmailVerified { + t.Error("email_verified must be true") + } +} + +func TestApple_EmailNotVerifiedString(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + + const clientID = "apple-client-id" + privateKeyPEM := generateTestECDSAKey(t) + + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "apple-sub-1", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "user@example.com", + "email_verified": "false", // Apple string format + }) + + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team-123", + KeyID: "key-123", + PrivateKey: privateKeyPEM, + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "the-code", RedirectURI: "https://app/cb"}) + if err == nil || !errors.Is(err, ErrEmailNotVerified) { + t.Fatalf("want ErrEmailNotVerified, got %v", err) + } +} + +func TestApple_MissingEmail(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + + const clientID = "apple-client-id" + privateKeyPEM := generateTestECDSAKey(t) + + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "apple-sub-1", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + // no email or email_verified + }) + + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team-123", + KeyID: "key-123", + PrivateKey: privateKeyPEM, + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "the-code", RedirectURI: "https://app/cb"}) + if err == nil || !errors.Is(err, ErrIdentityVerification) { + t.Fatalf("want ErrIdentityVerification, got %v", err) + } +} + +func TestApple_TokenEndpointError(t *testing.T) { + t.Parallel() + fp := newFakeProvider(t) + fp.tokenHandler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + } + exch := NewApple(AppleConfig{ + ClientID: "x", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "bad", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { + t.Fatalf("want ErrCodeExchangeFailed, got %v", err) + } +} + +func TestApple_NetworkError(t *testing.T) { + t.Parallel() + exch := NewApple(AppleConfig{ //nolint:gosec // dummy config + ClientID: "x", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), //nolint:gosec // this is a dummy test key + TokenURL: "http://127.0.0.1:1/", // closed port + JWKSURL: "http://127.0.0.1:1/", + HTTPClient: &http.Client{Timeout: 50 * time.Millisecond}, + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { + t.Fatalf("want ErrCodeExchangeFailed, got %v", err) + } +} + +func TestApple_BadPrivateKey(t *testing.T) { + t.Parallel() + exch := NewApple(AppleConfig{ + ClientID: "x", + TeamID: "team", + KeyID: "key", + PrivateKey: "invalid-pem", // will fail to sign secret + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil { + t.Fatalf("expected error due to invalid private key") + } +} + +func TestApple_BadSignature(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + + signing := newTestKey(t, "kid-attacker") + servingJWKS := newTestKey(t, "kid-server") // a different keypair + + idToken := signing.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "victim", + "aud": "client-id", + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "v@example.com", + "email_verified": true, + }) + + fp := newFakeProvider(t) + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", servingJWKS.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: "client-id", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrIdentityVerification) { + t.Fatalf("want ErrIdentityVerification, got %v", err) + } +} + +func TestApple_BadIssuer(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://evil.example", + "sub": "x", + "aud": "client-id", + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "v@example.com", + "email_verified": true, + }) + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: "client-id", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrIdentityVerification) { + t.Fatalf("want ErrIdentityVerification, got %v", err) + } +} + +func TestApple_BadAudience(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "x", + "aud": "different-audience", + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "v@example.com", + "email_verified": true, + }) + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: "client-id", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrIdentityVerification) { + t.Fatalf("want ErrIdentityVerification, got %v", err) + } +} + +func TestApple_ExpiredToken(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "x", + "aud": "client-id", + "iat": now.Add(-2 * time.Hour).Unix(), + "exp": now.Add(-1 * time.Hour).Unix(), + "email": "v@example.com", + "email_verified": true, + }) + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: "client-id", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) + if err == nil || !errors.Is(err, ErrIdentityVerification) { + t.Fatalf("want ErrIdentityVerification, got %v", err) + } + if !strings.Contains(err.Error(), "expired") { + t.Logf("unexpected error message: %v", err) + } +} + +func TestApple_JWKSCaching(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + const clientID = "client-id" + makeToken := func() string { + return key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "u", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "u@example.com", + "email_verified": true, + }) + } + fp.tokenHandler = func(w http.ResponseWriter, r *http.Request) { + jsonHandler(map[string]any{"id_token": makeToken()})(w, r) + } + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(t), + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + JWKSCacheTTL: time.Hour, + }) + + for i := 0; i < 2; i++ { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); err != nil { + t.Fatalf("exchange %d: %v", i, err) + } + } + if got := fp.jwksCalls.Load(); got != 1 { + t.Errorf("jwks fetched %d times, want 1", got) + } +} + +func TestApple_UserPayloadName(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(t, "kid-A") + fp := newFakeProvider(t) + + const clientID = "apple-client-id" + privateKeyPEM := generateTestECDSAKey(t) + + idToken := key.signIDToken(t, map[string]any{ + "iss": "https://appleid.apple.com", + "sub": "apple-sub-1", + "aud": clientID, + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + "email": "user@example.com", + "email_verified": true, + }) + + fp.tokenHandler = jsonHandler(map[string]any{"id_token": idToken}) + fp.jwksHandler = rawHandler(http.StatusOK, "application/json", key.JWKJSON) + + exch := NewApple(AppleConfig{ + ClientID: clientID, + TeamID: "team-123", + KeyID: "key-123", + PrivateKey: privateKeyPEM, + TokenURL: fp.URL("/token"), + JWKSURL: fp.URL("/jwks"), + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + }) + + ctx := context.Background() + id, err := exch.Exchange(ctx, ExchangeParams{ + Code: "the-code", + RedirectURI: "https://app/cb", + AppleUserPayload: `{"name":{"firstName":"John","lastName":"Doe"}}`, + }) + if err != nil { + t.Fatalf("Exchange: %v", err) + } + if id.Name != "John Doe" { + t.Errorf("expected Name 'John Doe', got %q", id.Name) + } +} diff --git a/pkg/oauth/authorize.go b/pkg/oauth/authorize.go index 9d847cfd..5fcac832 100644 --- a/pkg/oauth/authorize.go +++ b/pkg/oauth/authorize.go @@ -78,4 +78,5 @@ var ( _ Authorizer = (*googleExchanger)(nil) _ Authorizer = (*microsoftExchanger)(nil) _ Authorizer = (*githubExchanger)(nil) + _ Authorizer = (*appleExchanger)(nil) ) diff --git a/pkg/oauth/authorize_test.go b/pkg/oauth/authorize_test.go index 76f69e2d..01d32ba6 100644 --- a/pkg/oauth/authorize_test.go +++ b/pkg/oauth/authorize_test.go @@ -25,21 +25,6 @@ func newOAuthStateSignerWithKID(t *testing.T, kid string) *jwttest.Signer { return jwttest.NewSigner(t, kid) } -func TestCodeVerifierContext(t *testing.T) { - t.Parallel() - - var nilContext context.Context - if got := codeVerifierFromContext(nilContext); got != "" { - t.Fatalf("nil context verifier = %q", got) - } - if got := codeVerifierFromContext(WithCodeVerifier(context.Background(), " ")); got != "" { - t.Fatalf("blank verifier = %q", got) - } - if got := codeVerifierFromContext(WithCodeVerifier(context.Background(), " verifier-123 ")); got != "verifier-123" { - t.Fatalf("verifier = %q", got) - } -} - func TestStateToken_RoundTripAndMismatch(t *testing.T) { t.Parallel() diff --git a/pkg/oauth/context.go b/pkg/oauth/context.go deleted file mode 100644 index 0061efde..00000000 --- a/pkg/oauth/context.go +++ /dev/null @@ -1,25 +0,0 @@ -package oauth - -import ( - "context" - "strings" -) - -type codeVerifierContextKey struct{} - -// WithCodeVerifier carries the request-scoped PKCE verifier to the -// provider token exchange. -func WithCodeVerifier(ctx context.Context, codeVerifier string) context.Context { - if strings.TrimSpace(codeVerifier) == "" { - return ctx - } - return context.WithValue(ctx, codeVerifierContextKey{}, codeVerifier) -} - -func codeVerifierFromContext(ctx context.Context) string { - if ctx == nil { - return "" - } - codeVerifier, _ := ctx.Value(codeVerifierContextKey{}).(string) - return strings.TrimSpace(codeVerifier) -} diff --git a/pkg/oauth/extra_test.go b/pkg/oauth/extra_test.go index 02071778..fec04421 100644 --- a/pkg/oauth/extra_test.go +++ b/pkg/oauth/extra_test.go @@ -35,29 +35,29 @@ func TestNewExchangers_DefaultsApplied(t *testing.T) { func TestExchanger_MissingCodeOrCreds(t *testing.T) { t.Parallel() g := NewGoogle(GoogleConfig{ClientID: "x", ClientSecret: "y"}) - if _, err := g.Exchange(context.Background(), "", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := g.Exchange(context.Background(), ExchangeParams{Code: "", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("google empty code: %v", err) } g2 := NewGoogle(GoogleConfig{}) - if _, err := g2.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := g2.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("google empty creds: %v", err) } m := NewMicrosoft(MicrosoftConfig{}) - if _, err := m.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := m.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("ms empty creds: %v", err) } m2 := NewMicrosoft(MicrosoftConfig{ClientID: "x", ClientSecret: "y"}) - if _, err := m2.Exchange(context.Background(), "", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := m2.Exchange(context.Background(), ExchangeParams{Code: "", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("ms empty code: %v", err) } gh := NewGitHub(GitHubConfig{}) - if _, err := gh.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := gh.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("gh empty creds: %v", err) } gh2 := NewGitHub(GitHubConfig{ClientID: "x", ClientSecret: "y"}) - if _, err := gh2.Exchange(context.Background(), "", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := gh2.Exchange(context.Background(), ExchangeParams{Code: "", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Errorf("gh empty code: %v", err) } } @@ -200,7 +200,7 @@ func TestGoogle_JWKSRotationRetrySucceeds(t *testing.T) { Now: nowFunc(now), HTTPClient: srv.Client(), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); err != nil { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); err != nil { t.Fatalf("Exchange after rotation: %v", err) } if got := jwksHits.Load(); got != 2 { @@ -233,7 +233,7 @@ func TestMicrosoft_MissingTID(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -265,7 +265,7 @@ func TestMicrosoft_NoEmailFails(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -282,7 +282,7 @@ func TestGoogle_NoIDToken(t *testing.T) { TokenURL: fp.URL("/token"), JWKSURL: fp.URL("/jwks"), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } @@ -304,7 +304,7 @@ func TestGitHub_EmailEndpointBadJSON(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); err == nil { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); err == nil { t.Fatal("expected parse error") } } @@ -344,7 +344,7 @@ func TestGoogle_Exchange_BadResponseJSON(t *testing.T) { TokenURL: fp.URL("/token"), JWKSURL: fp.URL("/jwks"), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } } @@ -361,7 +361,7 @@ func TestMicrosoft_Exchange_NoIDToken(t *testing.T) { JWKSURL: fp.URL("/jwks"), IssuerFormat: "https://x/%s/v2.0", }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } } @@ -378,7 +378,7 @@ func TestGitHub_NoAccessToken(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrCodeExchangeFailed) { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } } @@ -396,7 +396,7 @@ func TestGitHub_UserMissingID(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); !errors.Is(err, ErrIdentityVerification) { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } } @@ -416,7 +416,7 @@ func TestGitHub_UserBadJSON(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - if _, err := exch.Exchange(context.Background(), "code", "https://x"); err == nil { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); err == nil { t.Fatal("expected parse error") } } diff --git a/pkg/oauth/fuzz_test.go b/pkg/oauth/fuzz_test.go new file mode 100644 index 00000000..ca56ab5b --- /dev/null +++ b/pkg/oauth/fuzz_test.go @@ -0,0 +1,86 @@ +package oauth + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// FuzzVerifyIDToken is a Go fuzz target that feeds arbitrary bytes to the +// provider verify functions to prove they don't panic on garbage input. +func FuzzVerifyIDToken(f *testing.F) { + // Seed the corpus with some interesting strings. + f.Add("") + f.Add("garbage") + f.Add("eyJhbGciOiJSUzI1NiIsImtpZCI6ImtpZC1BIn0.garbage.sig") + f.Add("eyJhbGciOiJSUzI1NiIsImtpZCI6ImtpZC1BIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.sig") + + now := time.Date(2026, 5, 3, 12, 0, 0, 0, time.UTC) + key := newTestKey(f, "kid-A") + + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) *http.Response { + if strings.HasSuffix(req.URL.Path, "/jwks") { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(key.JWKJSON)), + Header: make(http.Header), + } + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id_token":"ignore"}`)), + Header: make(http.Header), + } + }), + } + + //nolint:gosec // test fake credentials + apple := NewApple(AppleConfig{ + ClientID: "client", + TeamID: "team", + KeyID: "key", + PrivateKey: generateTestECDSAKey(f), + TokenURL: "https://apple/token", + JWKSURL: "https://apple/jwks", + Issuer: "https://appleid.apple.com", + Now: nowFunc(now), + HTTPClient: client, + }) + + //nolint:gosec // test fake credentials + google := NewGoogle(GoogleConfig{ + ClientID: "client", + TokenURL: "https://google/token", + JWKSURL: "https://google/jwks", + Issuer: "https://accounts.test", + Now: nowFunc(now), + HTTPClient: client, + }) + + //nolint:gosec // test fake credentials + microsoft := NewMicrosoft(MicrosoftConfig{ + ClientID: "client", + TokenURL: "https://microsoft/token", + JWKSURL: "https://microsoft/jwks", + Now: nowFunc(now), + HTTPClient: client, + }) + + f.Fuzz(func(t *testing.T, token string) { + ctx := context.Background() + _, _ = apple.(*appleExchanger).verifyIDToken(ctx, token) + _, _ = google.(*googleExchanger).verifyIDToken(ctx, token) + _, _ = microsoft.(*microsoftExchanger).verifyIDToken(ctx, token) + }) +} + +type roundTripFunc func(req *http.Request) *http.Response + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} diff --git a/pkg/oauth/github.go b/pkg/oauth/github.go index 886fb691..934398b8 100644 --- a/pkg/oauth/github.go +++ b/pkg/oauth/github.go @@ -98,8 +98,8 @@ type githubEmail struct { Verified bool `json:"verified"` } -func (g *githubExchanger) Exchange(ctx context.Context, code, redirectURI string) (*Identity, error) { - if code == "" { +func (g *githubExchanger) Exchange(ctx context.Context, params ExchangeParams) (*Identity, error) { + if params.Code == "" { return nil, fmt.Errorf("%w: missing authorization code", ErrCodeExchangeFailed) } if g.cfg.ClientID == "" || g.cfg.ClientSecret == "" { @@ -109,12 +109,12 @@ func (g *githubExchanger) Exchange(ctx context.Context, code, redirectURI string form := url.Values{} form.Set("client_id", g.cfg.ClientID) form.Set("client_secret", g.cfg.ClientSecret) - form.Set("code", code) - if redirectURI != "" { - form.Set("redirect_uri", redirectURI) + form.Set("code", params.Code) + if params.RedirectURI != "" { + form.Set("redirect_uri", params.RedirectURI) } - if codeVerifier := codeVerifierFromContext(ctx); codeVerifier != "" { - form.Set("code_verifier", codeVerifier) + if params.CodeVerifier != "" { + form.Set("code_verifier", params.CodeVerifier) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, g.cfg.TokenURL, diff --git a/pkg/oauth/github_test.go b/pkg/oauth/github_test.go index 59c88b66..0f8d0001 100644 --- a/pkg/oauth/github_test.go +++ b/pkg/oauth/github_test.go @@ -35,7 +35,7 @@ func TestGitHub_ExchangeSuccess(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - id, err := exch.Exchange(context.Background(), "code", "https://x") + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err != nil { t.Fatalf("Exchange: %v", err) } @@ -76,7 +76,7 @@ func TestGitHub_FallsBackToProfileEmail(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - id, err := exch.Exchange(context.Background(), "code", "https://x") + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err != nil { t.Fatalf("Exchange: %v", err) } @@ -103,7 +103,7 @@ func TestGitHub_NoVerifiedEmailRejected(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrEmailNotVerified) { t.Fatalf("want ErrEmailNotVerified, got %v", err) } @@ -123,7 +123,7 @@ func TestGitHub_TokenError(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } @@ -142,7 +142,7 @@ func TestGitHub_TokenEndpoint500(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } @@ -162,7 +162,7 @@ func TestGitHub_UserEndpointFailure(t *testing.T) { UserURL: fp.URL("/user"), UserMailURL: fp.URL("/user/emails"), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } diff --git a/pkg/oauth/google.go b/pkg/oauth/google.go index 099adcc2..a7329910 100644 --- a/pkg/oauth/google.go +++ b/pkg/oauth/google.go @@ -118,8 +118,8 @@ type googleTokenResponse struct { ErrorDesc string `json:"error_description"` } -func (g *googleExchanger) Exchange(ctx context.Context, code, redirectURI string) (*Identity, error) { - if code == "" { +func (g *googleExchanger) Exchange(ctx context.Context, params ExchangeParams) (*Identity, error) { + if params.Code == "" { return nil, fmt.Errorf("%w: missing authorization code", ErrCodeExchangeFailed) } if g.cfg.ClientID == "" || g.cfg.ClientSecret == "" { @@ -151,13 +151,13 @@ func (g *googleExchanger) Exchange(ctx context.Context, code, redirectURI string } form := url.Values{} - form.Set("code", code) + form.Set("code", params.Code) form.Set("client_id", g.cfg.ClientID) form.Set("client_secret", g.cfg.ClientSecret) - form.Set("redirect_uri", redirectURI) + form.Set("redirect_uri", params.RedirectURI) form.Set("grant_type", "authorization_code") - if codeVerifier := codeVerifierFromContext(ctx); codeVerifier != "" { - form.Set("code_verifier", codeVerifier) + if params.CodeVerifier != "" { + form.Set("code_verifier", params.CodeVerifier) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, @@ -281,18 +281,18 @@ func (g *googleExchanger) verifyIDToken(ctx context.Context, raw string) (*googl } payload, err := verifyJWS(raw, set) - if err != nil { - // On verification failure we may be looking at a stale cache - // after a key rotation. Invalidate and try once more. + if err != nil && errors.Is(err, errKeyNotFound) { + // On verification failure due to missing key, we may be looking + // at a stale cache after a key rotation. Invalidate and try once more. g.jwks.Invalidate() set2, fErr := g.jwks.Get(ctx) if fErr != nil { return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) } payload, err = verifyJWS(raw, set2) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) - } + } + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) } // Decode the JWT for issuer/audience/exp checks via jwx's parser @@ -329,6 +329,8 @@ func (g *googleExchanger) verifyIDToken(ctx context.Context, raw string) (*googl return &claims, nil } +var errKeyNotFound = errors.New("key not found in jwks") + // verifyJWS verifies the signature on a compact JWS using the // provided JWK set (matching kid → key). Returns the decoded payload // bytes on success. @@ -358,17 +360,17 @@ func verifyJWS(raw string, set jwk.Set) ([]byte, error) { if kid != "" { k, ok := set.LookupKeyID(kid) if !ok { - return nil, fmt.Errorf("no jwk for kid=%q", kid) + return nil, fmt.Errorf("%w: kid=%q", errKeyNotFound, kid) } key = k } else { // No kid — try the first key. if set.Len() == 0 { - return nil, errors.New("jwks empty") + return nil, fmt.Errorf("%w: empty", errKeyNotFound) } k, ok := set.Key(0) if !ok { - return nil, errors.New("jwks first key missing") + return nil, fmt.Errorf("%w: first key missing", errKeyNotFound) } key = k } diff --git a/pkg/oauth/google_test.go b/pkg/oauth/google_test.go index f98ebf20..482e0fa5 100644 --- a/pkg/oauth/google_test.go +++ b/pkg/oauth/google_test.go @@ -47,7 +47,7 @@ func TestGoogle_ExchangeSuccess(t *testing.T) { Now: nowFunc(now), }) - id, err := exch.Exchange(context.Background(), "the-code", "https://app/cb") + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "the-code", RedirectURI: "https://app/cb"}) if err != nil { t.Fatalf("Exchange: %v", err) } @@ -78,7 +78,7 @@ func TestGoogle_TokenEndpointError(t *testing.T) { TokenURL: fp.URL("/token"), JWKSURL: fp.URL("/jwks"), }) - _, err := exch.Exchange(context.Background(), "bad", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "bad", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } @@ -94,7 +94,7 @@ func TestGoogle_NetworkError(t *testing.T) { JWKSURL: "http://127.0.0.1:1/", HTTPClient: &http.Client{Timeout: 50 * time.Millisecond}, }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } @@ -129,7 +129,7 @@ func TestGoogle_BadSignature(t *testing.T) { Issuer: "https://accounts.test", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -160,7 +160,7 @@ func TestGoogle_BadIssuer(t *testing.T) { Issuer: "https://accounts.test", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -191,7 +191,7 @@ func TestGoogle_BadAudience(t *testing.T) { Issuer: "https://accounts.test", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -222,7 +222,7 @@ func TestGoogle_ExpiredToken(t *testing.T) { Issuer: "https://accounts.test", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -256,7 +256,7 @@ func TestGoogle_EmailNotVerified(t *testing.T) { Issuer: "https://accounts.test", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrEmailNotVerified) { t.Fatalf("want ErrEmailNotVerified, got %v", err) } @@ -297,7 +297,7 @@ func TestGoogle_JWKSCaching(t *testing.T) { // Two successful exchanges should hit JWKS only once. for i := 0; i < 2; i++ { - if _, err := exch.Exchange(context.Background(), "code", "https://x"); err != nil { + if _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}); err != nil { t.Fatalf("exchange %d: %v", i, err) } } diff --git a/pkg/oauth/hosted_state.go b/pkg/oauth/hosted_state.go index 8b1a5865..c3d7fc93 100644 --- a/pkg/oauth/hosted_state.go +++ b/pkg/oauth/hosted_state.go @@ -32,6 +32,7 @@ type HostedStateClaims struct { ReturnTo string State string CodeVerifier string + CSRFToken string IssuedAt int64 ExpiresAt int64 } @@ -42,7 +43,7 @@ type HostedStateClaims struct { func IssueHostedStateToken( ctx context.Context, signer identityjwt.Signer, - provider, redirectURI, returnTo, state, codeVerifier string, + provider, redirectURI, returnTo, state, codeVerifier, csrfToken string, expiry time.Duration, now time.Time, ) (string, error) { @@ -51,7 +52,8 @@ func IssueHostedStateToken( returnTo = strings.TrimSpace(returnTo) state = strings.TrimSpace(state) codeVerifier = strings.TrimSpace(codeVerifier) - if provider == "" || redirectURI == "" || returnTo == "" || state == "" || codeVerifier == "" { + csrfToken = strings.TrimSpace(csrfToken) + if provider == "" || redirectURI == "" || returnTo == "" || state == "" || codeVerifier == "" || csrfToken == "" { return "", fmt.Errorf("%w: missing required hosted-state claim", ErrStateValidation) } if signer == nil { @@ -65,6 +67,7 @@ func IssueHostedStateToken( "return_to": returnTo, "oauth_state": state, "code_verifier": codeVerifier, + "csrf_token": csrfToken, "iat": now.Unix(), "exp": now.Add(expiry).Unix(), } @@ -132,11 +135,12 @@ func VerifyHostedStateToken( ReturnTo: getStringClaim(tok, "return_to"), State: getStringClaim(tok, "oauth_state"), CodeVerifier: getStringClaim(tok, "code_verifier"), + CSRFToken: getStringClaim(tok, "csrf_token"), IssuedAt: tok.IssuedAt().Unix(), ExpiresAt: tok.Expiration().Unix(), } if claims.Provider == "" || claims.RedirectURI == "" || claims.ReturnTo == "" || - claims.State == "" || claims.CodeVerifier == "" { + claims.State == "" || claims.CodeVerifier == "" || claims.CSRFToken == "" { return nil, fmt.Errorf("%w: hosted state token missing required claims", ErrStateValidation) } if exp := tok.Expiration(); exp.IsZero() || now.After(exp) { diff --git a/pkg/oauth/hosted_state_test.go b/pkg/oauth/hosted_state_test.go index a0f76fc4..e76ac8ad 100644 --- a/pkg/oauth/hosted_state_test.go +++ b/pkg/oauth/hosted_state_test.go @@ -20,6 +20,7 @@ func TestHostedStateToken_RoundTrip(t *testing.T) { "https://app.example.com/finish", "state-abc", "verifier-abc", + "csrf-123", 5*time.Minute, now, ) @@ -43,6 +44,9 @@ func TestHostedStateToken_RoundTrip(t *testing.T) { if claims.RedirectURI != "https://identity.example.com/oauth/callback/google" { t.Errorf("redirect_uri = %q", claims.RedirectURI) } + if claims.CSRFToken != "csrf-123" { + t.Errorf("csrf_token = %q", claims.CSRFToken) + } } func TestHostedStateToken_RejectsExpired(t *testing.T) { @@ -53,7 +57,7 @@ func TestHostedStateToken_RejectsExpired(t *testing.T) { token, err := IssueHostedStateToken( context.Background(), ring, - "google", "https://identity.example.com/cb", "https://app.example.com/", "s", "v", + "google", "https://identity.example.com/cb", "https://app.example.com/", "s", "v", "c", 5*time.Minute, now, ) if err != nil { @@ -72,7 +76,7 @@ func TestHostedStateToken_RejectsTampered(t *testing.T) { token, err := IssueHostedStateToken( context.Background(), ring, - "google", "https://identity.example.com/cb", "https://app.example.com/", "s", "v", + "google", "https://identity.example.com/cb", "https://app.example.com/", "s", "v", "c", 5*time.Minute, now, ) if err != nil { @@ -111,7 +115,7 @@ func TestHostedStateToken_MissingClaims(t *testing.T) { ring := newOAuthStateSigner(t) now := time.Now().UTC() if _, err := IssueHostedStateToken( - context.Background(), ring, "google", "https://cb", "", "s", "v", time.Minute, now, + context.Background(), ring, "google", "https://cb", "", "s", "v", "c", time.Minute, now, ); err == nil { t.Fatal("IssueHostedStateToken should reject an empty return_to") } diff --git a/pkg/oauth/microsoft.go b/pkg/oauth/microsoft.go index efc2762c..2e0c0ad6 100644 --- a/pkg/oauth/microsoft.go +++ b/pkg/oauth/microsoft.go @@ -3,6 +3,7 @@ package oauth import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -115,23 +116,20 @@ type microsoftIDClaims struct { VerifiedEmail *bool `json:"verified_email"` } -func (m *microsoftExchanger) Exchange(ctx context.Context, code, redirectURI string) (*Identity, error) { - if code == "" { - return nil, fmt.Errorf("%w: missing authorization code", ErrCodeExchangeFailed) - } +func (m *microsoftExchanger) Exchange(ctx context.Context, params ExchangeParams) (*Identity, error) { if m.cfg.ClientID == "" || m.cfg.ClientSecret == "" { return nil, fmt.Errorf("%w: client credentials not configured", ErrCodeExchangeFailed) } form := url.Values{} - form.Set("code", code) + form.Set("code", params.Code) form.Set("client_id", m.cfg.ClientID) form.Set("client_secret", m.cfg.ClientSecret) - form.Set("redirect_uri", redirectURI) + form.Set("redirect_uri", params.RedirectURI) form.Set("grant_type", "authorization_code") form.Set("scope", "openid email profile") - if codeVerifier := codeVerifierFromContext(ctx); codeVerifier != "" { - form.Set("code_verifier", codeVerifier) + if params.CodeVerifier != "" { + form.Set("code_verifier", params.CodeVerifier) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.cfg.TokenURL, @@ -234,16 +232,16 @@ func (m *microsoftExchanger) verifyIDToken(ctx context.Context, raw string) (*mi } payload, err := verifyJWS(raw, set) - if err != nil { + if err != nil && errors.Is(err, errKeyNotFound) { m.jwks.Invalidate() set2, fErr := m.jwks.Get(ctx) if fErr != nil { return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) } payload, err = verifyJWS(raw, set2) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) - } + } + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrIdentityVerification, err) } tok, err := jwt.Parse(payload, jwt.WithVerify(false), jwt.WithValidate(false)) diff --git a/pkg/oauth/microsoft_test.go b/pkg/oauth/microsoft_test.go index a38d15c7..e18852a3 100644 --- a/pkg/oauth/microsoft_test.go +++ b/pkg/oauth/microsoft_test.go @@ -44,7 +44,7 @@ func TestMicrosoft_ExchangeSuccess(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - id, err := exch.Exchange(context.Background(), "code", "https://app/cb") + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://app/cb"}) if err != nil { t.Fatalf("Exchange: %v", err) } @@ -87,7 +87,7 @@ func TestMicrosoft_EmailFromUPNFallback(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - id, err := exch.Exchange(context.Background(), "code", "https://x") + id, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err != nil { t.Fatalf("Exchange: %v", err) } @@ -122,7 +122,7 @@ func TestMicrosoft_BadIssuer(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -154,7 +154,7 @@ func TestMicrosoft_BadAudience(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -186,7 +186,7 @@ func TestMicrosoft_ExpiredToken(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrIdentityVerification) { t.Fatalf("want ErrIdentityVerification, got %v", err) } @@ -219,7 +219,7 @@ func TestMicrosoft_VerifiedEmailFalse(t *testing.T) { IssuerFormat: "https://login.test/%s/v2.0", Now: nowFunc(now), }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrEmailNotVerified) { t.Fatalf("want ErrEmailNotVerified, got %v", err) } @@ -238,7 +238,7 @@ func TestMicrosoft_TokenEndpoint400(t *testing.T) { JWKSURL: fp.URL("/jwks"), IssuerFormat: "https://login.test/%s/v2.0", }) - _, err := exch.Exchange(context.Background(), "code", "https://x") + _, err := exch.Exchange(context.Background(), ExchangeParams{Code: "code", RedirectURI: "https://x"}) if err == nil || !errors.Is(err, ErrCodeExchangeFailed) { t.Fatalf("want ErrCodeExchangeFailed, got %v", err) } diff --git a/pkg/oauth/oauth_testutil_test.go b/pkg/oauth/oauth_testutil_test.go index 3ebe5f06..2abe6452 100644 --- a/pkg/oauth/oauth_testutil_test.go +++ b/pkg/oauth/oauth_testutil_test.go @@ -25,29 +25,29 @@ type testKey struct { JWKJSON []byte } -func newTestKey(t *testing.T, kid string) *testKey { - t.Helper() +func newTestKey(tb testing.TB, kid string) *testKey { + tb.Helper() priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - t.Fatalf("rsa generate: %v", err) + tb.Fatalf("rsa generate: %v", err) } pubKey, err := jwk.FromRaw(priv.Public()) if err != nil { - t.Fatalf("jwk from raw: %v", err) + tb.Fatalf("jwk from raw: %v", err) } if err := pubKey.Set(jwk.KeyIDKey, kid); err != nil { - t.Fatalf("set kid: %v", err) + tb.Fatalf("set kid: %v", err) } if err := pubKey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { - t.Fatalf("set alg: %v", err) + tb.Fatalf("set alg: %v", err) } set := jwk.NewSet() if err := set.AddKey(pubKey); err != nil { - t.Fatalf("add key: %v", err) + tb.Fatalf("add key: %v", err) } jwksJSON, err := json.Marshal(set) if err != nil { - t.Fatalf("marshal jwks: %v", err) + tb.Fatalf("marshal jwks: %v", err) } return &testKey{Priv: priv, JWKSet: set, KID: kid, JWKJSON: jwksJSON} } diff --git a/pkg/oauth/provider.go b/pkg/oauth/provider.go index 3444921c..727cf3d4 100644 --- a/pkg/oauth/provider.go +++ b/pkg/oauth/provider.go @@ -31,10 +31,18 @@ type Identity struct { // AvatarURL is a URL to the user's profile picture. May be empty. AvatarURL string - // Provider is the provider key — "google", "microsoft", "github". + // Provider is the provider key — "google", "microsoft", "github", "apple". Provider string } +// ExchangeParams contains the arguments for the OAuth token exchange. +type ExchangeParams struct { + Code string + RedirectURI string + CodeVerifier string // Optional PKCE code_verifier + AppleUserPayload string // Optional form-post payload for Apple first-time login +} + // Exchanger swaps an OAuth authorization code for a verified user // identity. Implementations are responsible for: // @@ -49,7 +57,7 @@ type Identity struct { // leak provider response bodies. Callers that need provider-specific // debugging should inspect logs. type Exchanger interface { - Exchange(ctx context.Context, code, redirectURI string) (*Identity, error) + Exchange(ctx context.Context, params ExchangeParams) (*Identity, error) } // Authorizer builds the provider authorization URL for the first half diff --git a/pkg/oauth/registry.go b/pkg/oauth/registry.go index 14070a29..ae24e207 100644 --- a/pkg/oauth/registry.go +++ b/pkg/oauth/registry.go @@ -2,7 +2,7 @@ package oauth import "sync" -// Registry maps provider keys ("google", "microsoft", "github") to +// Registry maps provider keys ("google", "microsoft", "github", "apple") to // their Exchanger implementations. The service layer looks up the // Exchanger for the provider named in the OAuthLoginRequest. // diff --git a/pkg/oauth/registry_test.go b/pkg/oauth/registry_test.go index 9601cb0f..3c80cf16 100644 --- a/pkg/oauth/registry_test.go +++ b/pkg/oauth/registry_test.go @@ -7,7 +7,7 @@ import ( type stubExchanger struct{ name string } -func (s stubExchanger) Exchange(_ context.Context, _, _ string) (*Identity, error) { +func (s stubExchanger) Exchange(_ context.Context, _ ExchangeParams) (*Identity, error) { return &Identity{Provider: s.name, Email: "x@x", EmailVerified: true}, nil } diff --git a/proto/identity/v1/identity.proto b/proto/identity/v1/identity.proto index b5a69002..ad19fa1c 100644 --- a/proto/identity/v1/identity.proto +++ b/proto/identity/v1/identity.proto @@ -199,7 +199,7 @@ message ListGroupMembersResponse { message BeginOAuthLoginRequest { string redirect_uri = 1; - string provider = 2; // "google", "microsoft", or "github" + string provider = 2; // "google", "microsoft", "github", or "apple" string tenant = 3; // Microsoft tenant ID (optional) } @@ -214,11 +214,12 @@ message BeginOAuthLoginResponse { message OAuthLoginRequest { string code = 1; string redirect_uri = 2; - string provider = 3; // "google", "microsoft", or "github" + string provider = 3; // "google", "microsoft", "github", or "apple" string code_verifier = 4; // PKCE (optional) string tenant = 5; // Microsoft tenant ID (optional) string state = 6; // authorization callback state (required for server-owned flow) string state_token = 7; // opaque token minted by BeginOAuthLogin + string apple_user_payload = 8; // one-time user payload from Apple's form_post callback } message OAuthLoginResponse { @@ -230,7 +231,7 @@ message OAuthLoginResponse { } // RedeemOAuthCode exchanges the single-use one-time code handed to the -// SPA by the hosted OAuth callback (GET /oauth/callback/{provider} -> +// SPA by the hosted OAuth callback (GET/POST /oauth/callback/{provider} -> // 302 return_to?code=) for a backend-issued token pair. The code // is single-use and short-lived; a replay returns CodeUnauthenticated. message RedeemOAuthCodeRequest { @@ -1243,7 +1244,7 @@ message GetProjectConfigResponse { // LinkedIdentity is one connected provider identity for the authenticated user. message LinkedIdentity { - string provider = 1; // "google", "microsoft", "github", ... + string provider = 1; // "google", "microsoft", "github", "apple", ... string provider_user_id = 2; // stable provider subject string email_at_link_time = 3; // email the provider asserted when linked int64 linked_at = 4; // epoch ms the link was created @@ -1264,7 +1265,7 @@ message ListLinkedIdentitiesResponse { message LinkIdentityRequest { string code = 1; // authorization code from the provider callback string redirect_uri = 2; - string provider = 3; // "google", "microsoft", "github", ... + string provider = 3; // "google", "microsoft", "github", "apple", ... string code_verifier = 4; // PKCE (optional) string state = 5; // authorization callback state string state_token = 6; // opaque token minted by BeginOAuthLogin diff --git a/tests/integration/oauth_hosted_test.go b/tests/integration/oauth_hosted_test.go index 78d1d614..b2d0fb8b 100644 --- a/tests/integration/oauth_hosted_test.go +++ b/tests/integration/oauth_hosted_test.go @@ -33,7 +33,7 @@ func (p *hostedStubProvider) AuthorizationURL(_ context.Context, redirectURI, st return u.String(), nil } -func (p *hostedStubProvider) Exchange(_ context.Context, _, _ string) (*oauth.Identity, error) { +func (p *hostedStubProvider) Exchange(_ context.Context, params oauth.ExchangeParams) (*oauth.Identity, error) { return p.identity, nil } @@ -105,7 +105,14 @@ func TestHostedOAuth_EndToEnd(t *testing.T) { // 2. Callback: provider redirects back with state + code. Expect a // 302 to return_to?code=. callbackURL := h.BaseURL + "/oauth/callback/google?state=" + url.QueryEscape(stateToken) + "&code=auth-code-xyz" - cbResp, err := client.Get(callbackURL) + req, err := http.NewRequest("GET", callbackURL, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + for _, c := range startResp.Cookies() { + req.AddCookie(c) + } + cbResp, err := client.Do(req) if err != nil { t.Fatalf("GET /oauth/callback: %v", err) } diff --git a/tests/integration/oauth_login_test.go b/tests/integration/oauth_login_test.go index ecdfbcce..694fe23d 100644 --- a/tests/integration/oauth_login_test.go +++ b/tests/integration/oauth_login_test.go @@ -22,7 +22,7 @@ type staticExchanger struct { err error } -func (s *staticExchanger) Exchange(_ context.Context, code, _ string) (*oauth.Identity, error) { +func (s *staticExchanger) Exchange(_ context.Context, params oauth.ExchangeParams) (*oauth.Identity, error) { if s.err != nil { return nil, s.err }