/*
 * Decompiled with CFR 0.152.
 */
package net.zaiyers.UUIDDB.lib.mongodb.internal.connection;

import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.HashMap;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import net.zaiyers.UUIDDB.lib.bson.BsonBoolean;
import net.zaiyers.UUIDDB.lib.bson.BsonDocument;
import net.zaiyers.UUIDDB.lib.bson.BsonString;
import net.zaiyers.UUIDDB.lib.mongodb.AuthenticationMechanism;
import net.zaiyers.UUIDDB.lib.mongodb.MongoCredential;
import net.zaiyers.UUIDDB.lib.mongodb.ServerAddress;
import net.zaiyers.UUIDDB.lib.mongodb.ServerApi;
import net.zaiyers.UUIDDB.lib.mongodb.assertions.Assertions;
import net.zaiyers.UUIDDB.lib.mongodb.connection.ClusterConnectionMode;
import net.zaiyers.UUIDDB.lib.mongodb.internal.authentication.NativeAuthenticationHelper;
import net.zaiyers.UUIDDB.lib.mongodb.internal.authentication.SaslPrep;
import net.zaiyers.UUIDDB.lib.mongodb.internal.connection.InternalConnection;
import net.zaiyers.UUIDDB.lib.mongodb.internal.connection.MongoCredentialWithCache;
import net.zaiyers.UUIDDB.lib.mongodb.internal.connection.SaslAuthenticator;
import net.zaiyers.UUIDDB.lib.mongodb.lang.Nullable;

/**
 * SASL authenticator for the SCRAM-SHA-1 and SCRAM-SHA-256 mechanisms.
 *
 * <p>The hash algorithms are chosen from the credential's {@link AuthenticationMechanism}:
 * {@code SCRAM_SHA_1} selects SHA-1 / HmacSHA1 together with the legacy hashed password, any other
 * mechanism selects SHA-256 / HmacSHA256 together with the plain password (SASLprep'd for
 * {@code SCRAM_SHA_256}). The derived client/server keys are cached on the
 * {@link MongoCredentialWithCache}, keyed by password digest, salt and iteration count, so the
 * expensive Hi() derivation is skipped when those inputs repeat.
 *
 * <p>Speculative authentication is supported: {@link #createSpeculativeAuthenticateCommand} creates
 * a SASL client and the first {@code saslStart} document up front. While that client is retained,
 * {@link #createSaslClient} returns it instead of a new one. Passing {@code null} to
 * {@link #setSpeculativeAuthenticateResponse} discards it.
 */
class ScramShaAuthenticator
extends SaslAuthenticator {
    private final RandomStringGenerator randomStringGenerator;
    private final AuthenticationHashGenerator authenticationHashGenerator;
    /** Client created by createSpeculativeAuthenticateCommand; returned by createSaslClient while non-null. */
    private SaslClient speculativeSaslClient;
    /** Server reply recorded via setSpeculativeAuthenticateResponse, or null if none was recorded. */
    private BsonDocument speculativeAuthenticateResponse;
    /** Smallest iteration count the client accepts from the server. */
    private static final int MINIMUM_ITERATION_COUNT = 4096;
    /** GS2 header: no channel binding, no authorization identity. */
    private static final String GS2_HEADER = "n,,";
    /** Number of characters in the client nonce. */
    private static final int RANDOM_LENGTH = 24;
    /** Big-endian INT(1) appended to the salt in the first round of Hi(). */
    private static final byte[] INT_1 = new byte[]{0, 0, 0, 1};
    /** Uses the credential's plain password; rejects a null password. */
    private static final AuthenticationHashGenerator DEFAULT_AUTHENTICATION_HASH_GENERATOR = credential -> {
        char[] password = credential.getPassword();
        if (password == null) {
            throw new IllegalArgumentException("Password must not be null");
        }
        return new String(password);
    };
    /** Uses NativeAuthenticationHelper.createAuthenticationHash(username, password); rejects null inputs. */
    private static final AuthenticationHashGenerator LEGACY_AUTHENTICATION_HASH_GENERATOR = credential -> {
        String username = credential.getUserName();
        char[] password = credential.getPassword();
        if (username == null || password == null) {
            throw new IllegalArgumentException("Username and password must not be null");
        }
        return NativeAuthenticationHelper.createAuthenticationHash(username, password);
    };

    /**
     * Creates an authenticator using a SecureRandom-backed nonce generator and the password hash
     * generator that matches the credential's mechanism.
     */
    ScramShaAuthenticator(MongoCredentialWithCache credential, ClusterConnectionMode clusterConnectionMode, @Nullable ServerApi serverApi) {
        this(credential, new DefaultRandomStringGenerator(), ScramShaAuthenticator.getAuthenticationHashGenerator(Assertions.assertNotNull(credential.getAuthenticationMechanism())), clusterConnectionMode, serverApi);
    }

    /**
     * Creates an authenticator with explicit nonce and password-hash generators.
     *
     * @param credential                  the credential plus its derived-key cache
     * @param randomStringGenerator       source of the client nonce
     * @param authenticationHashGenerator turns the credential into the password string fed to SCRAM
     * @param clusterConnectionMode       forwarded to SaslAuthenticator
     * @param serverApi                   forwarded to SaslAuthenticator; may be null
     */
    ScramShaAuthenticator(MongoCredentialWithCache credential, RandomStringGenerator randomStringGenerator, AuthenticationHashGenerator authenticationHashGenerator, ClusterConnectionMode clusterConnectionMode, @Nullable ServerApi serverApi) {
        super(credential, clusterConnectionMode, serverApi);
        this.randomStringGenerator = randomStringGenerator;
        this.authenticationHashGenerator = authenticationHashGenerator;
    }

    /**
     * @return the SASL mechanism name of the credential's authentication mechanism
     * @throws IllegalArgumentException if the credential has no authentication mechanism
     */
    @Override
    public String getMechanismName() {
        AuthenticationMechanism authMechanism = this.getMongoCredential().getAuthenticationMechanism();
        if (authMechanism == null) {
            throw new IllegalArgumentException("Authentication mechanism cannot be null");
        }
        return authMechanism.getMechanismName();
    }

    /** Adds {@code options: {skipEmptyExchange: true}} to the saslStart command. */
    @Override
    protected void appendSaslStartOptions(BsonDocument saslStartCommand) {
        saslStartCommand.append("options", new BsonDocument("skipEmptyExchange", new BsonBoolean(true)));
    }

    /** Returns the retained speculative client if there is one, otherwise a fresh SCRAM client. */
    @Override
    protected SaslClient createSaslClient(ServerAddress serverAddress) {
        if (this.speculativeSaslClient != null) {
            return this.speculativeSaslClient;
        }
        return new ScramShaSaslClient(this.getMongoCredentialWithCache(), this.randomStringGenerator, this.authenticationHashGenerator);
    }

    /**
     * Creates (and retains) a SASL client, evaluates its first step and returns the resulting
     * saslStart document with the credential's source database and the start options appended.
     * Any exception is passed through {@code wrapException}.
     */
    @Override
    public BsonDocument createSpeculativeAuthenticateCommand(InternalConnection connection) {
        try {
            this.speculativeSaslClient = this.createSaslClient(connection.getDescription().getServerAddress());
            BsonDocument startDocument = this.createSaslStartCommandDocument(this.speculativeSaslClient.evaluateChallenge(new byte[0])).append("db", new BsonString(this.getMongoCredential().getSource()));
            this.appendSaslStartOptions(startDocument);
            return startDocument;
        }
        catch (Exception e) {
            throw this.wrapException(e);
        }
    }

    /** @return the recorded speculative authentication response, or null if none was recorded */
    @Override
    @Nullable
    public BsonDocument getSpeculativeAuthenticateResponse() {
        return this.speculativeAuthenticateResponse;
    }

    /**
     * Records the server's speculative response. A null response discards the speculative client,
     * so the next {@link #createSaslClient} call builds a fresh one.
     */
    @Override
    public void setSpeculativeAuthenticateResponse(@Nullable BsonDocument response) {
        if (response == null) {
            this.speculativeSaslClient = null;
        } else {
            this.speculativeAuthenticateResponse = response;
        }
    }

    /** SCRAM-SHA-1 uses the legacy hashed password; every other mechanism uses the plain password. */
    private static AuthenticationHashGenerator getAuthenticationHashGenerator(AuthenticationMechanism authenticationMechanism) {
        return authenticationMechanism == AuthenticationMechanism.SCRAM_SHA_1 ? LEGACY_AUTHENTICATION_HASH_GENERATOR : DEFAULT_AUTHENTICATION_HASH_GENERATOR;
    }

    /**
     * Generates nonces of printable ASCII characters '!' (0x21) through '~' (0x7E) inclusive,
     * excluding ',' (which separates SCRAM attributes).
     */
    private static class DefaultRandomStringGenerator
    implements RandomStringGenerator {
        private static final int COMMA = ',';
        private static final int LOW = '!';
        /** Inclusive upper bound; the previous exclusive use of 126 silently dropped '~'. */
        private static final int HIGH = '~';
        private final SecureRandom random = new SecureRandom();

        private DefaultRandomStringGenerator() {
        }

        @Override
        public String generate(int length) {
            char[] text = new char[length];
            for (int i = 0; i < length; ++i) {
                int next;
                do {
                    next = this.random.nextInt(HIGH - LOW + 1) + LOW;
                } while (next == COMMA);
                text[i] = (char) next;
            }
            return new String(text);
        }
    }

    /** Produces the password string that SCRAM hashes for a given credential. */
    @FunctionalInterface
    public static interface AuthenticationHashGenerator {
        public String generate(MongoCredential credential);
    }

    /** Produces a random string of the requested length (used for the client nonce). */
    @FunctionalInterface
    public static interface RandomStringGenerator {
        public String generate(int length);
    }

    /**
     * Client side of a SCRAM conversation. Each {@link #evaluateChallenge} call advances one step:
     * 0 builds the client-first message, 1 consumes the server-first message and builds the
     * client-final message, 2 verifies the server signature in the server-final message.
     */
    class ScramShaSaslClient
    implements SaslClient {
        private final MongoCredentialWithCache credential;
        private final RandomStringGenerator randomStringGenerator;
        private final AuthenticationHashGenerator authenticationHashGenerator;
        /** MessageDigest algorithm used for H(). */
        private final String hAlgorithm;
        /** Mac algorithm used for HMAC() and Hi(). */
        private final String hmacAlgorithm;
        private String clientFirstMessageBare;
        private String clientNonce;
        /** ServerSignature expected in the server-final message; set while building the client-final message. */
        private byte[] serverSignature;
        /** -1 before the first evaluateChallenge call; incremented on every call. */
        private int step = -1;

        ScramShaSaslClient(MongoCredentialWithCache credential, RandomStringGenerator randomStringGenerator, AuthenticationHashGenerator authenticationHashGenerator) {
            this.credential = credential;
            this.randomStringGenerator = randomStringGenerator;
            this.authenticationHashGenerator = authenticationHashGenerator;
            if (Assertions.assertNotNull(credential.getAuthenticationMechanism()) == AuthenticationMechanism.SCRAM_SHA_1) {
                this.hAlgorithm = "SHA-1";
                this.hmacAlgorithm = "HmacSHA1";
            } else {
                this.hAlgorithm = "SHA-256";
                this.hmacAlgorithm = "HmacSHA256";
            }
        }

        @Override
        public String getMechanismName() {
            return Assertions.assertNotNull(this.credential.getAuthenticationMechanism()).getMechanismName();
        }

        @Override
        public boolean hasInitialResponse() {
            return true;
        }

        /**
         * Advances the conversation by one step.
         *
         * @throws SaslException on an invalid server message or when called more than three times
         */
        @Override
        public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
            ++this.step;
            switch (this.step) {
                case 0:
                    return this.computeClientFirstMessage();
                case 1:
                    return this.computeClientFinalMessage(challenge);
                case 2:
                    return this.validateServerSignature(challenge);
                default:
                    throw new SaslException(String.format("Too many steps involved in the %s negotiation.", this.getMechanismName()));
            }
        }

        /**
         * Verifies the server-final message: reports a server error ("e") if present, otherwise
         * compares the verifier ("v") with the expected ServerSignature in constant time.
         */
        private byte[] validateServerSignature(byte[] challenge) throws SaslException {
            String serverResponse = new String(challenge, StandardCharsets.UTF_8);
            HashMap<String, String> map = this.parseServerResponse(serverResponse);
            String serverError = map.get("e");
            if (serverError != null) {
                throw new SaslException(String.format("Server returned an error: %s", serverError));
            }
            byte[] receivedSignature = this.decodeBase64(this.getRequiredAttribute(map, "v"), "v");
            if (!MessageDigest.isEqual(receivedSignature, this.serverSignature)) {
                throw new SaslException("Server signature was invalid.");
            }
            return new byte[0];
        }

        @Override
        public boolean isComplete() {
            return this.step == 2;
        }

        @Override
        public byte[] unwrap(byte[] incoming, int offset, int len) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public byte[] wrap(byte[] outgoing, int offset, int len) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public Object getNegotiatedProperty(String propName) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public void dispose() {
        }

        /** Builds "n,,n=&lt;user&gt;,r=&lt;nonce&gt;" and remembers the bare part for the AuthMessage. */
        private byte[] computeClientFirstMessage() {
            this.clientNonce = this.randomStringGenerator.generate(RANDOM_LENGTH);
            this.clientFirstMessageBare = "n=" + this.getUserName() + ",r=" + this.clientNonce;
            return (GS2_HEADER + this.clientFirstMessageBare).getBytes(StandardCharsets.UTF_8);
        }

        /**
         * Consumes the server-first message (r, s, i), validates the nonce and iteration count, and
         * builds the client-final message carrying the ClientProof.
         */
        private byte[] computeClientFinalMessage(byte[] challenge) throws SaslException {
            String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8);
            HashMap<String, String> map = this.parseServerResponse(serverFirstMessage);
            String serverNonce = this.getRequiredAttribute(map, "r");
            // The server nonce must extend the nonce this client sent.
            if (!serverNonce.startsWith(this.clientNonce)) {
                throw new SaslException("Server sent an invalid nonce.");
            }
            String salt = this.getRequiredAttribute(map, "s");
            int iterationCount = this.parseIterationCount(this.getRequiredAttribute(map, "i"));
            if (iterationCount < MINIMUM_ITERATION_COUNT) {
                throw new SaslException("Invalid iteration count.");
            }
            String clientFinalMessageWithoutProof = "c=" + Base64.getEncoder().encodeToString(GS2_HEADER.getBytes(StandardCharsets.UTF_8)) + ",r=" + serverNonce;
            String authMessage = this.clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + this.getClientProof(this.getAuthenticationHash(), salt, iterationCount, authMessage);
            return clientFinalMessage.getBytes(StandardCharsets.UTF_8);
        }

        /**
         * Computes the Base64 ClientProof for {@code authMessage}, and as a side effect stores the
         * expected ServerSignature. Client/server keys are looked up in, or added to, the credential cache.
         *
         * @param password       the password string produced by the hash generator
         * @param salt           the Base64 salt sent by the server
         * @param iterationCount the Hi() iteration count sent by the server
         * @param authMessage    client-first-bare + "," + server-first + "," + client-final-without-proof
         */
        String getClientProof(String password, String salt, int iterationCount, String authMessage) throws SaslException {
            // Base64 instead of new String(digest, UTF_8): raw digest bytes are generally not valid UTF-8,
            // and lossy decoding would let distinct digests produce the same cache key.
            String hashedPasswordAndSalt = Base64.getEncoder().encodeToString(this.h((password + salt).getBytes(StandardCharsets.UTF_8)));
            CacheKey cacheKey = new CacheKey(hashedPasswordAndSalt, salt, iterationCount);
            CacheValue cachedKeys = ScramShaAuthenticator.this.getMongoCredentialWithCache().getFromCache(cacheKey, CacheValue.class);
            if (cachedKeys == null) {
                byte[] saltedPassword = this.hi(password.getBytes(StandardCharsets.UTF_8), this.decodeBase64(salt, "s"), iterationCount);
                cachedKeys = new CacheValue(this.hmac(saltedPassword, "Client Key"), this.hmac(saltedPassword, "Server Key"));
                ScramShaAuthenticator.this.getMongoCredentialWithCache().putInCache(cacheKey, cachedKeys);
            }
            this.serverSignature = this.hmac(cachedKeys.serverKey, authMessage);
            byte[] storedKey = this.h(cachedKeys.clientKey);
            byte[] clientSignature = this.hmac(storedKey, authMessage);
            byte[] clientProof = this.xor(cachedKeys.clientKey, clientSignature);
            return Base64.getEncoder().encodeToString(clientProof);
        }

        /** H(): digest of {@code data} with the mechanism's hash algorithm. */
        private byte[] h(byte[] data) throws SaslException {
            try {
                return MessageDigest.getInstance(this.hAlgorithm).digest(data);
            }
            catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", this.hAlgorithm), e);
            }
        }

        /** Hi(): U1 = HMAC(password, salt || INT(1)), U(k+1) = HMAC(password, Uk); returns U1 ^ ... ^ U(iterations). */
        private byte[] hi(byte[] password, byte[] salt, int iterations) throws SaslException {
            try {
                Mac mac = Mac.getInstance(this.hmacAlgorithm);
                mac.init(new SecretKeySpec(password, this.hmacAlgorithm));
                mac.update(salt);
                mac.update(INT_1);
                byte[] u = mac.doFinal();
                byte[] result = u.clone();
                for (int i = 1; i < iterations; ++i) {
                    // doFinal resets the Mac, so each call is a fresh HMAC over the previous U.
                    u = mac.doFinal(u);
                    this.xorInPlace(result, u);
                }
                return result;
            }
            catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", this.hmacAlgorithm), e);
            }
            catch (InvalidKeyException e) {
                throw new SaslException(String.format("Invalid key for %s", this.hmacAlgorithm), e);
            }
        }

        /** HMAC(key, UTF-8 bytes of data) with the mechanism's Mac algorithm. */
        private byte[] hmac(byte[] key, String data) throws SaslException {
            try {
                Mac mac = Mac.getInstance(this.hmacAlgorithm);
                mac.init(new SecretKeySpec(key, this.hmacAlgorithm));
                return mac.doFinal(data.getBytes(StandardCharsets.UTF_8));
            }
            catch (NoSuchAlgorithmException e) {
                throw new SaslException(String.format("Algorithm for '%s' could not be found.", this.hmacAlgorithm), e);
            }
            catch (InvalidKeyException e) {
                throw new SaslException("Could not initialize mac.", e);
            }
        }

        /**
         * Splits "k1=v1,k2=v2,..." into a map. Values keep any further '=' (e.g. Base64 padding).
         *
         * @throws SaslException if a comma-separated part has no '='
         */
        private HashMap<String, String> parseServerResponse(String response) throws SaslException {
            HashMap<String, String> attributes = new HashMap<String, String>();
            for (String pair : response.split(",")) {
                int separator = pair.indexOf('=');
                if (separator < 0) {
                    throw new SaslException("Malformed server response.");
                }
                attributes.put(pair.substring(0, separator), pair.substring(separator + 1));
            }
            return attributes;
        }

        /** Returns the named attribute, or throws SaslException if the server omitted it. */
        private String getRequiredAttribute(HashMap<String, String> attributes, String name) throws SaslException {
            String value = attributes.get(name);
            if (value == null) {
                throw new SaslException(String.format("Server response is missing the '%s' attribute.", name));
            }
            return value;
        }

        /** Parses the server's iteration count, mapping a non-numeric value to SaslException. */
        private int parseIterationCount(String value) throws SaslException {
            try {
                return Integer.parseInt(value);
            }
            catch (NumberFormatException e) {
                throw new SaslException("Invalid iteration count.", e);
            }
        }

        /** Decodes a Base64 attribute value, mapping malformed input to SaslException. */
        private byte[] decodeBase64(String value, String attributeName) throws SaslException {
            try {
                return Base64.getDecoder().decode(value);
            }
            catch (IllegalArgumentException e) {
                throw new SaslException(String.format("Server sent an invalid base64 value for '%s'.", attributeName), e);
            }
        }

        /** Returns the username with '=' and ',' escaped as "=3D" and "=2C" ('=' first, so escapes are not re-escaped). */
        private String getUserName() {
            String userName = this.credential.getCredential().getUserName();
            if (userName == null) {
                throw new IllegalArgumentException("Username can not be null");
            }
            return userName.replace("=", "=3D").replace(",", "=2C");
        }

        /** Returns the generated password string, SASLprep'd when the mechanism is SCRAM-SHA-256. */
        private String getAuthenticationHash() {
            String password = this.authenticationHashGenerator.generate(this.credential.getCredential());
            if (this.credential.getAuthenticationMechanism() == AuthenticationMechanism.SCRAM_SHA_256) {
                password = SaslPrep.saslPrepStored(password);
            }
            return password;
        }

        /** XORs {@code b} into {@code a} (b must be at least as long as a) and returns {@code a}. */
        private byte[] xorInPlace(byte[] a, byte[] b) {
            for (int i = 0; i < a.length; ++i) {
                a[i] ^= b[i];
            }
            return a;
        }

        /** Returns a new array a ^ b, leaving both inputs untouched. */
        private byte[] xor(byte[] a, byte[] b) {
            return this.xorInPlace(a.clone(), b);
        }
    }

    /** Cached ClientKey and ServerKey derived from a salted password. */
    private static class CacheValue {
        private final byte[] clientKey;
        private final byte[] serverKey;

        CacheValue(byte[] clientKey, byte[] serverKey) {
            this.clientKey = clientKey;
            this.serverKey = serverKey;
        }
    }

    /** Cache key: password+salt digest, salt and iteration count. */
    private static class CacheKey {
        private final String hashedPasswordAndSalt;
        private final String salt;
        private final int iterationCount;

        CacheKey(String hashedPasswordAndSalt, String salt, int iterationCount) {
            this.hashedPasswordAndSalt = hashedPasswordAndSalt;
            this.salt = salt;
            this.iterationCount = iterationCount;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            CacheKey that = (CacheKey) o;
            return this.iterationCount == that.iterationCount
                    && this.hashedPasswordAndSalt.equals(that.hashedPasswordAndSalt)
                    && this.salt.equals(that.salt);
        }

        @Override
        public int hashCode() {
            int result = this.hashedPasswordAndSalt.hashCode();
            result = 31 * result + this.salt.hashCode();
            result = 31 * result + this.iterationCount;
            return result;
        }
    }
}

