import { JWTVerifyGetKey, errors, importSPKI, jwtVerify } from 'jose'
import * as z from 'zod'
import { zodInvariant } from '@purposity/utils'
import env from '../env/server'
import { HasuraNextJwtClaimsSchema } from '../lib/clerk/jwt-claims.purposity'
import { AuthError } from '../lib/util/errors'

/**
 * Verifies a Clerk-issued JWT and validates its payload against
 * `HasuraNextJwtClaimsSchema`.
 *
 * The signature is checked with jose's `jwtVerify`, using `getClerkJwtKey` to
 * resolve the verification key. The payload is then checked with
 * `zodInvariant`.
 *
 * @param token - The raw JWT. Any non-string value is rejected up front.
 * @returns The verified JWT payload.
 * @throws AuthError (`INVALID_TOKEN`) when `token` is not a string, or when jose
 *   rejects the token (a `JOSEError`).
 * @throws AuthError (`INVALID_CLAIMS`) when claims validation throws a `ZodError`.
 * @throws Any other error unchanged. This includes a `ZodError` raised while
 *   the key resolver parses `env.CLERK_JWT_KEY`, so a server misconfiguration
 *   is not reported as bad client claims.
 */
export async function verifyClerkToken(token: string | unknown) {
  if (typeof token !== 'string') throw new AuthError(AuthError.INVALID_TOKEN)

  const { payload } = await verifyClerkSignature(token)

  // Only this step may turn a ZodError into INVALID_CLAIMS. Zod errors thrown
  // during key resolution have already propagated unchanged above.
  try {
    zodInvariant(HasuraNextJwtClaimsSchema, payload)
    return payload
  } catch (err) {
    if (err instanceof z.ZodError) {
      throw new AuthError(AuthError.INVALID_CLAIMS, { cause: err })
    }
    throw err
  }
}

/**
 * Runs jose signature/structure verification for `token`.
 * Maps jose failures to `INVALID_TOKEN` and rethrows anything else unchanged.
 */
async function verifyClerkSignature(token: string) {
  try {
    return await jwtVerify(token, getClerkJwtKey)
  } catch (err) {
    if (err instanceof errors.JOSEError) {
      throw new AuthError(AuthError.INVALID_TOKEN, { cause: err })
    }
    throw err
  }
}

/**
 * Imported Clerk verification keys, keyed by the JWS `alg` they were imported
 * for. Entries whose import rejects are evicted in `getClerkJwtKey`. This
 * stops the token-supplied `alg` header from growing the map without bound.
 */
const clerkJwtKeyCache = new Map<string, ReturnType<typeof importSPKI>>()

/**
 * Builds an SPKI PEM document from `env.CLERK_JWT_KEY`.
 *
 * Accepts either the bare base64 key body or a full PEM. PEM armor lines and
 * all whitespace (e.g. line breaks) are stripped first. The remaining body is
 * then re-wrapped in 64-character lines between
 * `-----BEGIN PUBLIC KEY-----` / `-----END PUBLIC KEY-----`.
 *
 * @throws z.ZodError (path `env.CLERK_JWT_KEY`) when the value is not a string,
 *   or nothing remains after stripping armor and whitespace.
 */
function buildClerkSpki(): string {
  return z
    .string()
    .transform((v) => v.replace(/-----(?:BEGIN|END) PUBLIC KEY-----|\s/g, ''))
    .refine((v) => v.length > 0, { message: 'Invalid SPKI' })
    .transform((v) => (v.match(/.{1,64}/g) ?? []).join('\n'))
    .transform(
      (v) => `-----BEGIN PUBLIC KEY-----\n${v}\n-----END PUBLIC KEY-----`
    )
    .parse(env.CLERK_JWT_KEY, {
      path: ['env', 'CLERK_JWT_KEY'],
    })
}

/**
 * jose key resolver for Clerk tokens.
 *
 * Imports the configured Clerk public key for the `alg` named in the token's
 * protected header. The resulting promise is memoized per `alg`, so the env
 * value is parsed and the key imported once, not on every verification.
 * Concurrent callers for the same `alg` share one import.
 */
const getClerkJwtKey: JWTVerifyGetKey = (protectedHeader, _token) => {
  const { alg } = protectedHeader
  const cached = clerkJwtKeyCache.get(alg)
  if (cached) return cached

  // buildClerkSpki may throw synchronously (ZodError for a bad env value).
  // In that case nothing is cached and the error reaches jwtVerify directly.
  const key = importSPKI(buildClerkSpki(), alg)
  clerkJwtKeyCache.set(alg, key)
  // `alg` comes from the untrusted token header. Evict an import that rejects
  // so it is neither retained nor reused. The rejection itself still reaches
  // jwtVerify through the returned promise.
  void key.catch(() => {
    if (clerkJwtKeyCache.get(alg) === key) clerkJwtKeyCache.delete(alg)
  })
  return key
}

/**
 * Non-throwing wrapper around `verifyClerkToken`.
 *
 * Resolves to `{ success: true, data }` with the verified claims on success.
 * Otherwise resolves to `{ success: false, error }`, where `error` is always an
 * `AuthError`. Errors that are not already `AuthError`s are wrapped as
 * `AuthError.UNKNOWN`, with the original error kept as `cause`.
 */
export async function verifyClerkTokenSafe(
  token: string | unknown
): Promise<
  | { success: true; data: z.infer<typeof HasuraNextJwtClaimsSchema> }
  | { success: false; error: AuthError }
> {
  try {
    return { success: true, data: await verifyClerkToken(token) }
  } catch (caught) {
    const error = AuthError.isAuthError(caught)
      ? caught
      : new AuthError(AuthError.UNKNOWN, { cause: caught })
    return { success: false, error }
  }
}
