import { CompactEncrypt, importJWK } from "jose";
import PropTypes from "prop-types";
import React, { createContext, useEffect, useRef, useState } from "react";
import { ValidationError } from "yup";

import { swaDate } from "@swa-ui/date";
import { usePersistedState } from "@swa-ui/persistence";

import {
  ENCRYPTION_PROVIDER_COMPONENT_NAME,
  ENCRYPTION_TYPE,
  MAXIMUM_NUMBER_OF_ENCRYPTION_ATTEMPTS,
  MAXIMUM_NUMBER_OF_RETRIES,
} from "../defines/constants";
import {
  API_FETCH_ERROR,
  DATA_ENCRYPTION_FAILED,
  DATA_SCHEMA_MISMATCH,
  KEY_CRITERIA_MISMATCH,
  KEY_RETRIEVAL_FAILED,
  KEY_WILL_EXPIRE_SOON,
  NO_KEY_CONFIG_MAP_OR_CONFIGS_EXIST,
  NO_SCHEMA_PROVIDED,
} from "../defines/loggerMessages";
import { STATE_SETTINGS } from "../defines/stateSettings";
import { MILLISECONDS_IN_A_SECOND, SECONDS_IN_A_HOUR } from "../defines/timeConstants";
import { convertToMilliSeconds } from "../utils/convertToMilliSeconds";
import { removeDuplicateValues, removeExtraKeys, removeObjectKey } from "../utils/modifyObjects";

/**
 * Default value seen by consumers rendered outside an `EncryptionProvider`:
 * `encrypt` is a stub that simply yields an empty string.
 */
const defaultEncryptionContextValue = {
  encrypt: () => "",
};

/** React context through which `EncryptionProvider` exposes its encryption helpers. */
export const EncryptionContext = createContext(defaultEncryptionContextValue);

export const EncryptionProvider = (props) => {
  const { apiKey = "", channelId, jwtSettings = {}, logger, schemas } = props;
  const { configs = [], keyConfigMap = {} } = jwtSettings;
  const [jwksState, setJwksState] = usePersistedState(STATE_SETTINGS.jwksState);
  const [refreshEncryptionToken, setRefreshEncryptionToken] = useState(swaDate());
  const [clearRefreshInterval, setClearRefreshInterval] = useState(false);
  const numberOfFailedEncryptionAttempts = useRef(0);
  const numberOfRetries = useRef({});
  const modifiedKeyConfigMap = removeDuplicateValues(keyConfigMap);

  useEffect(() => {
    const intervals = [];

    if (
      configs.length &&
      Object.keys(modifiedKeyConfigMap).length &&
      checkIfAnyConfigIsEnabled(configs)
    ) {
      const publicKeyPromises = Object.values(modifiedKeyConfigMap).map(
        async (jwksConfigMapType) => {
          const jwtConfig = jwtSettings.configs.find((config) => config.id === jwksConfigMapType);
          const { jwksTtlInSeconds } = jwtConfig;
          let publicKeyToStore = {};

          const fetchedPublicKeys = await refreshPublicKey(jwksConfigMapType, jwtConfig);
          const freshPublicKey = validatePublicKeys(
            jwksConfigMapType,
            fetchedPublicKeys?.keys ?? []
          );

          const refreshInterval = setInterval(async () => {
            const fetchedRefreshIntervalPublicKeys = await refreshPublicKey(
              jwksConfigMapType,
              jwtConfig
            );
            const freshRefreshIntervalPublicKey = validatePublicKeys(
              jwksConfigMapType,
              fetchedRefreshIntervalPublicKeys?.keys ?? []
            );

            if (freshRefreshIntervalPublicKey) {
              publicKeyToStore = createPublicKeyToStore(
                freshRefreshIntervalPublicKey,
                jwksConfigMapType,
                jwksTtlInSeconds
              );
            }

            setPublicKeyInSession([publicKeyToStore]);
          }, jwksTtlInSeconds * MILLISECONDS_IN_A_SECOND);

          intervals.push(refreshInterval);

          if (freshPublicKey) {
            publicKeyToStore = createPublicKeyToStore(
              freshPublicKey,
              jwksConfigMapType,
              jwksTtlInSeconds
            );
          }

          return Promise.resolve(publicKeyToStore);
        }
      );

      Promise.all(publicKeyPromises).then((publicKeys) => {
        const allPublicKeys = Object.assign({}, ...publicKeys);

        setJwksState(allPublicKeys);
      });
    } else {
      sessionStorage.removeItem(STATE_SETTINGS.jwksState.key);

      if (!configs.length || !Object.keys(modifiedKeyConfigMap).length) {
        logger.error(
          NO_KEY_CONFIG_MAP_OR_CONFIGS_EXIST,
          ENCRYPTION_PROVIDER_COMPONENT_NAME,
          jwtSettings
        );
      }
    }

    return () => intervals.forEach((interval) => clearInterval(interval));
  }, []);

  useEffect(() => {
    let refreshEncryptionTokenInterval;

    if (
      configs.length &&
      Object.keys(modifiedKeyConfigMap).length &&
      checkIfAnyConfigIsEnabled(configs) &&
      !clearRefreshInterval
    ) {
      refreshEncryptionTokenInterval = setInterval(() => {
        setRefreshEncryptionToken(swaDate());
      }, convertToMilliSeconds(jwtSettings?.jweTokenRefreshTtlInSeconds));
    } else {
      clearInterval(refreshEncryptionTokenInterval);
    }

    return () => {
      clearInterval(refreshEncryptionTokenInterval);
    };
  }, [clearRefreshInterval]);

  return (
    <EncryptionContext.Provider value={{ encrypt, refreshEncryptionToken }}>
      {props.children}
    </EncryptionContext.Provider>
  );

  function setPublicKeyInSession(publicKeys) {
    const allPublicKeys = Object.assign({}, ...publicKeys);

    return setJwksState((prevState) => ({ ...prevState, ...allPublicKeys }));
  }

  function createPublicKeyToStore(freshPublicKey, jwksConfigMapType, jwksTtlInSeconds) {
    const publicKeyToStore = {
      [jwksConfigMapType]: {
        publicKey: {
          ...freshPublicKey,
        },
        publicKeyTimeToLiveTime: now() + jwksTtlInSeconds,
      },
    };

    return publicKeyToStore;
  }

  function checkIfAnyConfigIsEnabled(data) {
    return data.some((config) => config.enable === true);
  }

  async function getPublicKey(jwksConfigMapType, jwtConfig) {
    const { endpointUrl } = jwtConfig;

    try {
      const response = await fetch(endpointUrl, {
        headers: {
          "x-api-key": apiKey,
          "x-channel-id": channelId,
          Accept: "application/json",
        },
      });

      if (!response.ok) {
        const error = new Error(response.statusText);

        error.response = response;
        throw error;
      } else {
        const data = await response.json();

        if (Object.keys(data).length === 0) {
          setJwksRetryTimer(jwksConfigMapType, jwtConfig);
        }

        return data;
      }
    } catch (error) {
      const existingKey = useExistingPublicKey(jwksConfigMapType);

      if (error?.response) {
        const { response } = error;

        logger.error(
          API_FETCH_ERROR,
          ENCRYPTION_PROVIDER_COMPONENT_NAME,
          response,
          response?.status,
          endpointUrl
        );
      } else {
        logger.error(KEY_RETRIEVAL_FAILED, ENCRYPTION_PROVIDER_COMPONENT_NAME, error);
      }

      handleJwksRetry(jwksConfigMapType, jwtConfig);

      return existingKey;
    }
  }

  async function refreshPublicKey(jwksConfigMapType, jwtConfig) {
    const { enable } = jwtConfig;
    const configState = jwksState[jwksConfigMapType];
    const { publicKey = {}, publicKeyTimeToLiveTime } = configState ?? {};
    let modifiedJwksState;
    let result = Promise.resolve(publicKey);

    modifiedJwksState = removeExtraKeys(modifiedKeyConfigMap, jwksState);

    if (!enable) {
      modifiedJwksState = removeObjectKey(jwksState, jwksConfigMapType);
    }

    setJwksState(modifiedJwksState);

    if ((!publicKey || isNewKeyRequired(publicKey, publicKeyTimeToLiveTime)) && enable) {
      result = getPublicKey(jwksConfigMapType, jwtConfig);
    } else if (publicKeyTimeToLiveTime > publicKey.expires_on) {
      logger.error(
        KEY_WILL_EXPIRE_SOON,
        ENCRYPTION_PROVIDER_COMPONENT_NAME,
        jwksState[jwtConfig.id].publicKey.expires_on
      );
      const { expires_on } = jwksState[jwtConfig?.id]?.publicKey;
      const oneHourBeforeExpiration = convertToMilliSeconds(expires_on - SECONDS_IN_A_HOUR);
      const randomTimeToRetry = convertToMilliSeconds(
        Math.floor(Math.random() * (expires_on - now() + 1))
      );

      if (
        oneHourBeforeExpiration - convertToMilliSeconds(jwtConfig.jwksRandomRetryMinimumInSeconds) <
        randomTimeToRetry
      ) {
        result = useExistingPublicKey(jwksConfigMapType);
      } else {
        result = new Promise((resolve) => {
          setTimeout(() => resolve(getPublicKey(jwksConfigMapType, jwtConfig)), randomTimeToRetry);
        });
      }
    }

    return result;
  }

  function setJwksRetryTimer(jwksConfigMapType, jwtConfig) {
    const { jwksRetryTtlInSeconds } = jwtConfig;

    return setTimeout(() => {
      refreshPublicKey(jwksConfigMapType, jwtConfig);
    }, convertToMilliSeconds(jwksRetryTtlInSeconds));
  }

  function useExistingPublicKey(keyType) {
    const { publicKey = {} } = jwksState[keyType] ?? {};

    return isValidKey(publicKey) ? publicKey : undefined;
  }

  function validatePublicKeys(keyType, keys) {
    const encryptionKeys = keys.filter(isValidKey);
    let publicKeyToUse;

    if (encryptionKeys.length) {
      publicKeyToUse = encryptionKeys.reduce((prev, current) =>
        prev.expires_on > current.expires_on ? prev : current
      );
    } else {
      const existingKey = useExistingPublicKey(keyType);

      if (existingKey) {
        publicKeyToUse = existingKey;
      }
    }

    return publicKeyToUse;
  }

  function isValidKey(key) {
    return key && key.expires_on > now() && key.use === "enc";
  }

  function isNewKeyRequired(publicKeyToCheck, publicKeyTimeToLiveTimeToCheck) {
    return (
      !publicKeyToCheck ||
      now() > publicKeyToCheck.expires_on ||
      now() > publicKeyTimeToLiveTimeToCheck ||
      Object.keys(publicKeyToCheck).length === 0
    );
  }

  function now() {
    return Math.floor(swaDate() / MILLISECONDS_IN_A_SECOND);
  }

  async function encrypt(dataToEncrypt, component, jwksFeature) {
    const jwtConfigType = keyConfigMap[jwksFeature] || {};
    let encryptedData = "";

    if (Object.keys(keyConfigMap).length) {
      try {
        const { publicKey } = jwksState[jwtConfigType];

        if (isValidKey(publicKey)) {
          if (schemas) {
            const formData = schemas.validateSync(dataToEncrypt);

            encryptedData = await encryptData(formData, publicKey, jwtConfigType);
          } else {
            handleEncryptionErrors("warn", NO_SCHEMA_PROVIDED, component, {});
          }
        } else {
          handleEncryptionErrors("error", KEY_CRITERIA_MISMATCH, component, {});
        }
      } catch (error) {
        if (error instanceof ValidationError) {
          handleEncryptionErrors("error", DATA_SCHEMA_MISMATCH, component, error);
        } else {
          handleEncryptionErrors("error", DATA_ENCRYPTION_FAILED, component, error);
        }
      }
    }

    return encryptedData;
  }

  async function encryptData(data, publicKey, jwtConfigType) {
    const jwkConfigObject = jwtSettings.configs.find((config) => config.id === jwtConfigType);
    const {
      createJWTForPayload = false,
      createPayloadWithExactAppData = false,
      enc,
      iss,
      jweTokenTtlInSeconds,
    } = jwkConfigObject;
    const formatData = {
      exp: swaDate().valueOf() + convertToMilliSeconds(jweTokenTtlInSeconds),
      iat: now(),
      iss,
      ...getPayloadData(data, createPayloadWithExactAppData),
    };
    const payloadForJWE = createJWTForPayload
      ? generateUnsignedJWT(formatData)
      : JSON.stringify(formatData);
    const rsaPublicKey = await importJWK(
      {
        kty: publicKey.kty,
        kid: publicKey.kid,
        e: publicKey.e,
        n: publicKey.n,
        x5c: publicKey.x5c,
        x5t: publicKey.x5t,
      },
      publicKey.alg
    );
    const jwe = await new CompactEncrypt(new TextEncoder().encode(payloadForJWE))
      .setProtectedHeader({
        alg: publicKey.alg,
        enc,
        kid: publicKey.kid,
        typ: ENCRYPTION_TYPE,
      })
      .encrypt(rsaPublicKey);

    return jwe;
  }

  function getPayloadData(payloadData, createPayloadWithExactAppData) {
    const nestedPayloadData = {
      request: {
        ...payloadData,
      },
    };

    return createPayloadWithExactAppData ? payloadData : nestedPayloadData;
  }

  function generateUnsignedJWT(dataToTokenize) {
    const jwtHeader = {
      alg: "none",
      typ: "JWT",
    };
    const encodedHeader = convertStringToBase64(JSON.stringify(jwtHeader));
    const encodedPayload = convertStringToBase64(JSON.stringify(dataToTokenize));
    const unsignedJWT = `${encodedHeader}.${encodedPayload}.`;

    return unsignedJWT;
  }

  function convertStringToBase64(string) {
    return btoa(string);
  }

  function handleEncryptionErrors(type, ...rest) {
    if (!clearRefreshInterval) {
      numberOfFailedEncryptionAttempts.current = numberOfFailedEncryptionAttempts.current + 1;

      if (MAXIMUM_NUMBER_OF_ENCRYPTION_ATTEMPTS >= numberOfFailedEncryptionAttempts.current) {
        logger[type](...rest);
      } else {
        setClearRefreshInterval(true);
      }
    }
  }

  function handleJwksRetry(jwksConfigMapType, jwtConfig) {
    const retriesForType = numberOfRetries.current[jwksConfigMapType] ?? 0;

    if (MAXIMUM_NUMBER_OF_RETRIES > retriesForType) {
      numberOfRetries.current[jwksConfigMapType] = retriesForType + 1;

      setJwksRetryTimer(jwksConfigMapType, jwtConfig);
    }
  }
};

EncryptionProvider.propTypes = {
  /** API key sent as the `x-api-key` header on JWKS requests. */
  apiKey: PropTypes.string.isRequired,

  /** Channel identifier sent as the `x-channel-id` header on JWKS requests. */
  channelId: PropTypes.string.isRequired,

  /** Content rendered inside the provider. */
  children: PropTypes.node.isRequired,

  /** JWT/JWE settings. */
  jwtSettings: PropTypes.shape({
    /** One entry per JWKS key source. */
    configs: PropTypes.arrayOf(
      PropTypes.shape({
        /** Timeout in milliseconds for the JWKS endpoint request. */
        callTimeoutInMillis: PropTypes.number.isRequired,

        /** When true, the claims are wrapped in an unsigned JWT before encryption. */
        createJWTForPayload: PropTypes.bool,

        /** When true, app data is placed in the claims as-is instead of under `request`. */
        createPayloadWithExactAppData: PropTypes.bool,

        /** JWE content-encryption algorithm (the `enc` protected-header value). */
        enc: PropTypes.string.isRequired,

        /** Flag to enable or disable encryption for the feature. */
        enable: PropTypes.bool.isRequired,

        /** URL of the JWKS endpoint that serves the public keys. */
        endpointUrl: PropTypes.string.isRequired,

        /** Config identifier; the values of `keyConfigMap` refer to it. */
        id: PropTypes.string.isRequired,

        /** Minimum random retry time for the JWKS call, in seconds. */
        jwksRandomRetryMinimumInSeconds: PropTypes.number.isRequired,

        /** Delay before retrying a failed or empty JWKS call, in seconds. */
        jwksRetryTtlInSeconds: PropTypes.number.isRequired,

        /** Lifetime added to the JWE claims' `exp`, in seconds. */
        jweTokenTtlInSeconds: PropTypes.number.isRequired,

        /**
         * Time-to-live for JWE token refresh, in seconds.
         * NOTE(review): EncryptionProvider only reads the top-level
         * `jwtSettings.jweTokenRefreshTtlInSeconds`; confirm whether this per-config field is used.
         */
        jweTokenRefreshTtlInSeconds: PropTypes.number.isRequired,

        /** How long a fetched key is used before refetching, in seconds. */
        jwksTtlInSeconds: PropTypes.number.isRequired,
      })
    ).isRequired,

    /** Map of feature names (the `jwksFeature` argument of `encrypt`) to config `id`s. */
    keyConfigMap: PropTypes.objectOf(PropTypes.string).isRequired,

    /** Interval for refreshing `refreshEncryptionToken`, in seconds. */
    jweTokenRefreshTtlInSeconds: PropTypes.number.isRequired,
  }).isRequired,

  /** Logger used for error and warning reports. */
  logger: PropTypes.shape({
    /** Function to log errors. */
    error: PropTypes.func.isRequired,

    /** Function to log warnings. */
    warn: PropTypes.func.isRequired,
  }).isRequired,

  /**
   * Schema exposing `validateSync` (e.g. a yup schema), applied to data before encryption.
   * Optional: null/undefined is accepted because the prop is not required.
   */
  schemas: PropTypes.object,
};
