/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.crypto.fips;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.security.InvalidAlgorithmParameterException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;
import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.Algorithm;
import org.bouncycastle.crypto.AsymmetricPrivateKey;
import org.bouncycastle.crypto.AsymmetricPublicKey;
import org.bouncycastle.crypto.Parameters;
import org.bouncycastle.crypto.SymmetricKey;
import org.bouncycastle.crypto.SymmetricSecretKey;
import org.bouncycastle.crypto.asymmetric.AsymmetricECPrivateKey;
import org.bouncycastle.crypto.asymmetric.AsymmetricECPublicKey;
import org.bouncycastle.crypto.asymmetric.ECDomainParameters;
import org.bouncycastle.crypto.fips.FipsAES;
import org.bouncycastle.crypto.fips.FipsAgreement;
import org.bouncycastle.crypto.fips.FipsEC;
import org.bouncycastle.crypto.fips.FipsKDF;
import org.bouncycastle.crypto.fips.FipsKeyUnwrapper;
import org.bouncycastle.crypto.fips.FipsKeyWrapper;
import org.bouncycastle.jcajce.spec.ECDomainParameterSpec;
import org.keycloak.common.util.Base64Url;
import org.keycloak.jose.jwe.JWEHeader;
import org.keycloak.jose.jwe.JWEKeyStorage;
import org.keycloak.jose.jwe.alg.JWEAlgorithmProvider;
import org.keycloak.jose.jwe.enc.JWEEncryptionProvider;
import org.keycloak.jose.jwk.ECPublicJWK;
import org.keycloak.jose.jwk.JWKUtil;

public class BCFIPSEcdhEsAlgorithmProvider
implements JWEAlgorithmProvider {
    public byte[] decodeCek(byte[] encodedCek, Key encryptionKey, JWEHeader header, JWEEncryptionProvider encryptionProvider) throws Exception {
        int keyDataLength = BCFIPSEcdhEsAlgorithmProvider.getKeyDataLength(header.getAlgorithm(), encryptionProvider);
        PublicKey sharedPublicKey = BCFIPSEcdhEsAlgorithmProvider.toPublicKey(header.getEphemeralPublicKey());
        String algorithmID = BCFIPSEcdhEsAlgorithmProvider.getAlgorithmID(header.getAlgorithm(), header.getEncryptionAlgorithm());
        byte[] derivedKey = BCFIPSEcdhEsAlgorithmProvider.deriveKey(sharedPublicKey, encryptionKey, keyDataLength, algorithmID, this.base64UrlDecode(header.getAgreementPartyUInfo()), this.base64UrlDecode(header.getAgreementPartyVInfo()));
        if ("ECDH-ES".equals(header.getAlgorithm())) {
            return derivedKey;
        }
        SymmetricSecretKey aesKey = new SymmetricSecretKey((Parameters)FipsAES.KW, derivedKey);
        FipsAES.KeyWrapOperatorFactory factory = new FipsAES.KeyWrapOperatorFactory();
        FipsKeyUnwrapper unwrapper = factory.createKeyUnwrapper((SymmetricKey)aesKey, FipsAES.KW);
        return unwrapper.unwrap(encodedCek, 0, encodedCek.length);
    }

    public byte[] encodeCek(JWEEncryptionProvider encryptionProvider, JWEKeyStorage keyStorage, Key encryptionKey, JWEHeader.JWEHeaderBuilder headerBuilder) throws Exception {
        JWEHeader header = headerBuilder.build();
        int keyDataLength = BCFIPSEcdhEsAlgorithmProvider.getKeyDataLength(header.getAlgorithm(), encryptionProvider);
        ECParameterSpec params = ((ECPublicKey)encryptionKey).getParams();
        KeyPair ephemeralKeyPair = BCFIPSEcdhEsAlgorithmProvider.generateEcKeyPair(params);
        ECPublicKey ephemeralPublicKey = (ECPublicKey)ephemeralKeyPair.getPublic();
        ECPrivateKey ephemeralPrivateKey = (ECPrivateKey)ephemeralKeyPair.getPrivate();
        byte[] agreementPartyUInfo = header.getAgreementPartyUInfo() != null ? this.base64UrlDecode(header.getAgreementPartyUInfo()) : new byte[]{};
        byte[] agreementPartyVInfo = header.getAgreementPartyVInfo() != null ? this.base64UrlDecode(header.getAgreementPartyVInfo()) : new byte[]{};
        headerBuilder.ephemeralPublicKey(BCFIPSEcdhEsAlgorithmProvider.toECPublicJWK(ephemeralPublicKey));
        String algorithmID = BCFIPSEcdhEsAlgorithmProvider.getAlgorithmID(header.getAlgorithm(), header.getEncryptionAlgorithm());
        byte[] derivedKey = BCFIPSEcdhEsAlgorithmProvider.deriveKey(encryptionKey, ephemeralPrivateKey, keyDataLength, algorithmID, agreementPartyUInfo, agreementPartyVInfo);
        if ("ECDH-ES".equals(header.getAlgorithm())) {
            keyStorage.setCEKBytes(derivedKey);
            encryptionProvider.deserializeCEK(keyStorage);
            return new byte[0];
        }
        byte[] inputKeyBytes = keyStorage.getCekBytes();
        byte[] keyBytes = derivedKey;
        SymmetricSecretKey aesKey = new SymmetricSecretKey((Parameters)FipsAES.KW, keyBytes);
        FipsAES.KeyWrapOperatorFactory factory = new FipsAES.KeyWrapOperatorFactory();
        FipsKeyWrapper wrapper = factory.createKeyWrapper((SymmetricKey)aesKey, FipsAES.KW);
        return wrapper.wrap(inputKeyBytes, 0, inputKeyBytes.length);
    }

    private byte[] base64UrlDecode(String encoded) {
        return Base64Url.decode((String)(encoded == null ? "" : encoded));
    }

    private static KeyPair generateEcKeyPair(ECParameterSpec params) {
        try {
            KeyPairGenerator keyGen = KeyPairGenerator.getInstance("EC", "BCFIPS");
            SecureRandom randomGen = SecureRandom.getInstance("DEFAULT", "BCFIPS");
            keyGen.initialize(params, randomGen);
            return keyGen.generateKeyPair();
        }
        catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException | NoSuchProviderException e) {
            throw new IllegalArgumentException(e);
        }
    }

    private static byte[] deriveOtherInfo(int keyDataLength, String algorithmID, byte[] agreementPartyUInfo, byte[] agreementPartyVInfo) {
        byte[] algorithmId = BCFIPSEcdhEsAlgorithmProvider.encodeDataLengthData(algorithmID.getBytes(Charset.forName("ASCII")));
        byte[] partyUInfo = BCFIPSEcdhEsAlgorithmProvider.encodeDataLengthData(agreementPartyUInfo);
        byte[] partyVInfo = BCFIPSEcdhEsAlgorithmProvider.encodeDataLengthData(agreementPartyVInfo);
        byte[] suppPubInfo = BCFIPSEcdhEsAlgorithmProvider.toByteArray(keyDataLength);
        byte[] suppPrivInfo = BCFIPSEcdhEsAlgorithmProvider.emptyBytes();
        return BCFIPSEcdhEsAlgorithmProvider.concat(algorithmId, partyUInfo, partyVInfo, suppPubInfo, suppPrivInfo);
    }

    public static byte[] deriveKey(Key publicKey, Key privateKey, int keyDataLength, String algorithmID, byte[] agreementPartyUInfo, byte[] agreementPartyVInfo) {
        byte[] otherInfo = BCFIPSEcdhEsAlgorithmProvider.deriveOtherInfo(keyDataLength, algorithmID, agreementPartyUInfo, agreementPartyVInfo);
        FipsEC.DHAgreementFactory factory = new FipsEC.DHAgreementFactory();
        FipsAgreement agree = factory.createAgreement((AsymmetricPrivateKey)new AsymmetricECPrivateKey((Algorithm)FipsEC.ALGORITHM, privateKey.getEncoded()), FipsEC.DH.withKDF(FipsKDF.CONCATENATION.withPRF(FipsKDF.AgreementKDFPRF.SHA256), otherInfo, keyDataLength / 8));
        return agree.calculate((AsymmetricPublicKey)new AsymmetricECPublicKey((Algorithm)FipsEC.ALGORITHM, publicKey.getEncoded()));
    }

    private static ECPublicJWK toECPublicJWK(ECPublicKey ecKey) {
        ECPublicJWK k = new ECPublicJWK();
        int fieldSize = ecKey.getParams().getCurve().getField().getFieldSize();
        k.setCrv("P-" + fieldSize);
        k.setKeyType("EC");
        k.setX(Base64Url.encode((byte[])JWKUtil.toIntegerBytes((BigInteger)ecKey.getW().getAffineX(), (int)fieldSize)));
        k.setY(Base64Url.encode((byte[])JWKUtil.toIntegerBytes((BigInteger)ecKey.getW().getAffineY(), (int)fieldSize)));
        return k;
    }

    private static PublicKey toPublicKey(ECPublicJWK jwk) {
        String crv = jwk.getCrv();
        String xStr = jwk.getX();
        String yStr = jwk.getY();
        if (crv == null) {
            throw new IllegalArgumentException("JWK crv must be set");
        }
        if (xStr == null) {
            throw new IllegalArgumentException("JWK x must be set");
        }
        if (yStr == null) {
            throw new IllegalArgumentException("JWK y must be set");
        }
        BigInteger x = new BigInteger(1, Base64Url.decode((String)xStr));
        BigInteger y = new BigInteger(1, Base64Url.decode((String)yStr));
        try {
            ECPoint point = new ECPoint(x, y);
            X9ECParameters ecParams = NISTNamedCurves.getByName((String)crv);
            ECDomainParameterSpec params = new ECDomainParameterSpec(new ECDomainParameters(ecParams.getCurve(), ecParams.getG(), ecParams.getN(), ecParams.getH()));
            ECPublicKeySpec pubKeySpec = new ECPublicKeySpec(point, (ECParameterSpec)params);
            KeyFactory keyFactory = KeyFactory.getInstance("EC", "BCFIPS");
            return keyFactory.generatePublic(pubKeySpec);
        }
        catch (NoSuchAlgorithmException | NoSuchProviderException | InvalidKeySpecException e) {
            throw new IllegalArgumentException(e);
        }
    }

    private static String getAlgorithmID(String alg, String enc) {
        if ("ECDH-ES+A128KW".equals(alg) || "ECDH-ES+A192KW".equals(alg) || "ECDH-ES+A256KW".equals(alg)) {
            return alg;
        }
        if ("ECDH-ES".equals(alg)) {
            return enc;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    private static int getKeyDataLength(String alg, JWEEncryptionProvider encryptionProvider) {
        if ("ECDH-ES+A128KW".equals(alg)) {
            return 128;
        }
        if ("ECDH-ES+A192KW".equals(alg)) {
            return 192;
        }
        if ("ECDH-ES+A256KW".equals(alg)) {
            return 256;
        }
        if ("ECDH-ES".equals(alg)) {
            return encryptionProvider.getExpectedCEKLength() * 8;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    private static byte[] encodeDataLengthData(byte[] data) {
        byte[] databytes = data != null ? data : new byte[]{};
        byte[] datalen = BCFIPSEcdhEsAlgorithmProvider.toByteArray(databytes.length);
        return BCFIPSEcdhEsAlgorithmProvider.concat(datalen, databytes);
    }

    private static byte[] emptyBytes() {
        return new byte[0];
    }

    private static byte[] toByteArray(int intValue) {
        return new byte[]{(byte)(intValue >> 24), (byte)(intValue >> 16), (byte)(intValue >> 8), (byte)intValue};
    }

    private static byte[] concat(byte[] ... byteArrays) {
        byte[] throwable;
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try {
            for (byte[] bytes : byteArrays) {
                if (bytes == null) continue;
                baos.write(bytes);
            }
            throwable = baos.toByteArray();
        }
        catch (Throwable throwable2) {
            try {
                try {
                    baos.close();
                }
                catch (Throwable throwable3) {
                    throwable2.addSuppressed(throwable3);
                }
                throw throwable2;
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
        baos.close();
        return throwable;
    }
}

