(ns codescene.features.util.jwt
  "JWT utilities. Use make-jwt, unpack-jwt and unsign-jws to create and verify tokens.

  When you use make-jwt, any private claims are encrypted per the JWE standard. The result
  is an encrypted token that can be decrypted by anyone holding the same shared secret.

  If there are public claims, those claims are signed with an EdDSA key per the JWS standard,
  and they can be verified using the public key.

  When both private and public claims are present, the JWE token holding the private claims
  is added to the public claims under the key `:?`, and the result is signed.

  unpack-jwt returns the decrypted and verified claims merged, with private claims taking
  precedence.

  Function unsign-jws will only validate and return public claims. This is for callers
  that only have the public key to verify public claims but don't have the shared secret
  to access private data.

  The keys to use are specified by the keyring parameter, which is a vector of key maps.

  Each key needs to specify ':alg' (`:dir` for a symmetric key, `:eddsa` for an EdDSA key),
  ':kid' (key ID), and either ':secret' for a symmetric key or ':public-key' and optionally
  ':private-key' for an EdDSA key.

  When creating a JWT, the first EdDSA or symmetric key in the keyring will be used (as needed).
  When unpacking, the 'kid' field of the token header is used to select a key from the keyring.

  See codescene.crypto.keygen for key generation."
  (:require [buddy.core.codecs :as codecs]
            [buddy.core.keys :as keys]
            [buddy.sign.jwt :as jwt]
            [clj-time.coerce :as tcc]
            [clj-time.core :as tc]
            [medley.core :as m]
            [codescene.util.json :as json])
  (:import (java.nio.charset StandardCharsets)))

(defn init-keyring
  "Return `keyring` as a vector of key specs in which any string `:private-key` /
  `:public-key` value has been parsed into a key object with keys/str->private-key /
  keys/str->public-key. Non-string values, and specs lacking those entries, pass through
  unchanged."
  [keyring]
  (letfn [(coerce [parse k]
            ;; Only strings need parsing; already-parsed keys (or nil) are kept as-is.
            (cond-> k (string? k) parse))
          (init-key [key-spec]
            (-> key-spec
                (m/update-existing :private-key (partial coerce keys/str->private-key))
                (m/update-existing :public-key (partial coerce keys/str->public-key))))]
    (into [] (map init-key) keyring)))

(defn base64-decode-object
  "Decode `s` as URL-safe Base64 (the `true` flag of codecs/b64->bytes), interpret the
  bytes as UTF-8 text, and parse that text with json/parse-string."
  [s]
  (let [^bytes raw (codecs/b64->bytes s true)
        text (String. raw StandardCharsets/UTF_8)]
    (json/parse-string text)))

(defn extract-jwt
  "Decode the header and payload of a compact three-segment token `jwt` WITHOUT
  verifying its signature.

  Returns {:jwt jwt :header <decoded header> :payload <decoded payload>}, or nil when
  `jwt` is not a string made of exactly three non-empty dot-separated segments
  (for example a five-segment JWE, or a string with trailing extra segments)."
  [jwt]
  ;; re-matches anchors the pattern to the whole string; an unanchored re-find would
  ;; accept e.g. \"a.b.c.d\" and, for a `:dir` JWE (empty second segment), would
  ;; decode later segments as if they were the header and payload.
  (when-let [[_ header payload] (and (string? jwt)
                                     (re-matches #"([^.]+)\.([^.]+)\.([^.]+)" jwt))]
    {:jwt jwt
     :header (base64-decode-object header)
     :payload (base64-decode-object payload)}))

(defn find-key
  "Return the first key spec in `keyring` whose values equal those of every entry in
  `key-desc` (so an empty `key-desc` matches the first spec). Returns nil when no spec
  matches."
  [keyring key-desc]
  (let [wanted (keys key-desc)
        matches? (fn [key-spec]
                   (= key-desc (select-keys key-spec wanted)))]
    (first (filter matches? keyring))))

(defn- correct-keytype?
  "True when token type `typ` is paired with the algorithm this namespace uses for it:
  \"JWS\" with :eddsa, or \"JWE\" with :dir. False otherwise."
  [typ alg]
  (contains? #{["JWS" :eddsa] ["JWE" :dir]} [typ alg]))

(defn validated-header-key
  "Validate a decoded token `header` and return the keyring entry to unsign or decrypt with.

  Checks, in order:
  1. the header carries a `:kid`;
  2. the header's own `:typ`/`:alg` pair is one this namespace produces
     (\"JWS\" with :eddsa, or \"JWE\" with :dir);
  3. some keyring entry agrees with whichever of the header's `:alg`, `:enc` and `:kid`
     entries are present (see find-key).

  Throws ex-info with data {:type :validation :cause :keys} when any check fails."
  [keyring {:keys [kid typ alg] :as header}]
  (let [th #(throw (ex-info % {:type :validation :cause :keys}))]
    (cond
      (not kid)
      (th "No 'kid' in JWT")

      ;; Use the header's alg, not the matched key's: otherwise a missing key yields a nil
      ;; alg, fails this check, and is misreported as a wrong-type error.
      ;; NOTE(review): assumes jwt/decode-header yields :alg as a keyword (e.g. :eddsa),
      ;; which find-key's match against keyring :alg values already relies on.
      (not (correct-keytype? typ alg))
      (th (format "Key kid=%s is of wrong type for typ=%s" kid typ))

      :else
      (or (find-key keyring (select-keys header [:alg :enc :kid]))
          (th (format "No valid key found for typ=%s kid=%s alg=%s" typ kid alg))))))

(defn eddsa-signed
  "Sign `claims` into an EdDSA JWS using private key `privk`. When `header` is non-nil it
  is passed to jwt/sign under the :header option; a nil `header` omits that option."
  [claims privk header]
  (let [opts (cond-> {:alg :eddsa}
               (some? header) (assoc :header header))]
    (jwt/sign claims privk opts)))

(defn eddsa-unsigned
  "Verify EdDSA-signed JWS `jwt` against public key `pubk` via jwt/unsign and return its
  result (the token's claims)."
  [jwt pubk]
  (let [opts {:alg :eddsa}]
    (jwt/unsign jwt pubk opts)))

(defn jwe-encrypt
  "Encrypt `claims` into a direct-key (`:alg :dir`) JWE using A128CBC-HS256.

  Suitable when the same entity both creates and consumes the tokens. `secret` may be
  raw bytes or a URL-safe Base64 string, which is decoded first; it should be 32 random
  bytes. The protected header carries `kid` and typ \"JWE\"."
  [claims kid secret]
  (let [key-bytes (cond-> secret
                    (string? secret) (codecs/b64->bytes true))
        ;; JWE puts the protected header into the AAD, so the kid is integrity-protected
        ;; and can be trusted when decrypting.
        opts {:alg :dir
              :enc :a128cbc-hs256
              :header {:kid kid :typ "JWE"}}]
    (jwt/encrypt claims key-bytes opts)))

(defn claims
  "Build the standard claims map for a new token.

  - `issuer`   -> `:iss` (always present, even when nil).
  - `:iat`     -> the current time in epoch seconds.
  - `duration` -> added to the current time (via tc/plus) to form `:exp` in epoch seconds;
                  nil omits `:exp`. Presumably a clj-time period.
  - `aud`      -> `:aud`; nil omits it.
  - `leeway`   -> its length in seconds (tc/in-seconds) is stored under `:leeway`;
                  nil omits it.

  `:iat` and `:exp` are derived from a single reading of the clock, so they are always
  exactly `duration` apart instead of possibly straddling a second boundary."
  [duration leeway issuer aud]
  (let [now (tc/now)]
    (m/assoc-some {:iss issuer
                   :iat (tcc/to-epoch now)}
                  :exp (some->> duration (tc/plus now) tcc/to-epoch)
                  :aud aud
                  ;; leeway is how many seconds after expiry the validation still works
                  :leeway (when leeway (tc/in-seconds leeway)))))

(defn make-jwt
  "Make a JWS, a JWE, or a JWS wrapping a JWE, depending on which claims are given.

  - Non-empty `private-claims` are encrypted into a JWE with the first `:alg :dir` key.
  - Non-empty `public-claims` are signed into a JWS with the first `:alg :eddsa` key.
    If a JWE was produced, it is embedded in the signed claims under `:?`.
  - With only private claims the bare JWE is returned. With neither, nil is returned.

  `keyring` is a sequence of key spec maps (matched with find-key), not a map keyed by KID.
  Throws ex-info with data {:type :validation :cause :keys} when a required key is absent
  from the keyring."
  [public-claims private-claims keyring]
  (let [required-key (fn [key-desc]
                       (or (find-key keyring key-desc)
                           (throw (ex-info (format "No key with alg=%s in keyring" (:alg key-desc))
                                           {:type :validation :cause :keys}))))
        ;; Only look up the symmetric key when there is something to encrypt.
        jwe (when (not-empty private-claims)
              (let [{:keys [kid secret]} (required-key {:alg :dir})]
                (jwe-encrypt private-claims kid secret)))]
    (if (empty? public-claims)
      jwe
      (let [{:keys [kid private-key]} (required-key {:alg :eddsa})
            signed-claims (cond-> public-claims
                            jwe (assoc :? jwe))]
        (eddsa-signed signed-claims private-key {:kid kid :typ "JWS"})))))

(defn th
  "Throw an ex-info with message `s` and data {:type :validation :cause :jwt-type}."
  [s]
  (let [data {:type :validation :cause :jwt-type}]
    (throw (ex-info s data))))

(defn unsign-jws
  "Verify a signed JWS token and return its public claims. Throws an error if the token
  header's typ is not \"JWS\". Key lookup errors come from validated-header-key.

  This is useful instead of unpack-jwt when you know there's encrypted content in the JWT
  and you either don't have its key or don't care about it.

  No decryption is attempted. Any embedded JWE stays in the returned claims, as the
  still-encrypted value under `:?`."
  [jwt keyring]
  (let [header (jwt/decode-header jwt)
        typ (:typ header)]
    (when-not (= "JWS" typ)
      (th (format "The token is not a JWS token, typ=%s" typ)))
    (->> header
         (validated-header-key keyring)
         :public-key
         (eddsa-unsigned jwt))))

(defn- decrypt-jwe
  "Decrypt a direct-key (A128CBC-HS256) JWE `jwt` and return jwt/decrypt's result (its
  claims). The secret comes from the keyring entry chosen by validated-header-key; a string
  secret is decoded as URL-safe Base64 first. Throws via `th` when the header typ is not
  \"JWE\"."
  [jwt keyring]
  (let [header (jwt/decode-header jwt)
        typ (:typ header)]
    (when-not (= "JWE" typ)
      (th (format "The token is not a JWE token, typ=%s" typ)))
    ;; JWE puts the header in the AAD, so the kid used for key lookup can be trusted.
    (let [secret (:secret (validated-header-key keyring header))
          key-bytes (cond-> secret
                      (string? secret) (codecs/b64->bytes true))]
      (jwt/decrypt jwt key-bytes {:alg :dir :enc :a128cbc-hs256}))))

(defn unpack-jwt
  "Unpack a JWS or JWE token into its claims. `keyring` is a sequence of key specs.

  - typ \"JWE\": decrypt and return the claims.
  - typ \"JWS\": verify the signature. If the signed claims contain an embedded JWE under
    `:?`, decrypt it too and merge it into the public claims, with private claims taking
    precedence. The `:?` entry is removed from the result.
  - Any other typ throws via `th`."
  [jwt keyring]
  (let [typ (:typ (jwt/decode-header jwt))]
    (condp = typ
      "JWE" (decrypt-jwe jwt keyring)
      "JWS" (let [public (unsign-jws jwt keyring)
                  nested (:? public)
                  private (when nested (decrypt-jwe nested keyring))]
              (merge (dissoc public :?) private))
      (th (format "Unknown typ=%s" typ)))))