/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.model;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.model.Guardrail;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;

public class ModelGuardrail
extends Guardrail {
    @Generated
    private static final Logger log = LogManager.getLogger(ModelGuardrail.class);
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String RESPONSE_FILTER_FIELD = "response_filter";
    public static final String RESPONSE_VALIDATION_REGEX_FIELD = "response_validation_regex";
    private String modelId;
    private String responseFilter;
    private String responseAccept;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    private SdkClient sdkClient;
    private String tenantId;
    private Pattern regexAcceptPattern;

    public ModelGuardrail(String modelId, String responseFilter, String responseAccept) {
        this.modelId = modelId;
        this.responseFilter = responseFilter;
        this.responseAccept = responseAccept;
    }

    public ModelGuardrail(@NonNull Map<String, Object> params) {
        this((String)params.get(MODEL_ID_FIELD), (String)params.get(RESPONSE_FILTER_FIELD), (String)params.get(RESPONSE_VALIDATION_REGEX_FIELD));
        Objects.requireNonNull(params, "params is marked non-null but is null");
    }

    public ModelGuardrail(StreamInput input) throws IOException {
        this.modelId = input.readString();
        this.responseFilter = input.readString();
        this.responseAccept = input.readString();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(this.modelId);
        out.writeString(this.responseFilter);
        out.writeString(this.responseAccept);
    }

    private Boolean validateAcceptRegex(String input) {
        Matcher matcher = this.regexAcceptPattern.matcher(input);
        return matcher.matches();
    }

    @Override
    public Boolean validate(String in, Map<String, String> parameters2) {
        String input;
        String string = input = parameters2 == null ? null : parameters2.get("question");
        if (input == null || input.isEmpty()) {
            log.info("Guardrail request is empty.");
            return true;
        }
        log.info("Guardrail request: {}", (Object)input);
        AtomicBoolean isAccepted = new AtomicBoolean(true);
        ActionListener internalListener = ActionListener.wrap(predictionResponse -> {
            ModelTensorOutput output = (ModelTensorOutput)predictionResponse.getOutput();
            ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0);
            String guardrailResponse = AccessController.doPrivileged(() -> StringUtils.gson.toJson(tensor.getDataAsMap().get("response")));
            log.info("Guardrail response: {}", (Object)guardrailResponse);
            if (!this.validateAcceptRegex(guardrailResponse).booleanValue()) {
                isAccepted.set(false);
            }
        }, e -> log.error("[ModelGuardrail] Failed to get prediction response.", (Throwable)e));
        ActionListener<MLTaskResponse> actionListener = this.wrapActionListener(internalListener, res -> {
            MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res);
            return predictionResponse;
        });
        CountDownLatch latch = new CountDownLatch(1);
        HashMap<String, String> guardrailModelParams = new HashMap<String, String>();
        guardrailModelParams.put("question", input);
        if (this.responseFilter != null && !this.responseFilter.isEmpty()) {
            guardrailModelParams.put(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        log.info("Guardrail resFilter: {}", (Object)this.responseFilter);
        MLPredictionTaskRequest request = new MLPredictionTaskRequest(this.modelId, RemoteInferenceMLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()).build());
        this.client.execute((ActionType)MLPredictionTaskAction.INSTANCE, (ActionRequest)request, (ActionListener)new LatchedActionListener(actionListener, latch));
        try {
            latch.await(5L, TimeUnit.SECONDS);
        }
        catch (InterruptedException e2) {
            log.error("[ModelGuardrail] Validation was timeout.", (Throwable)e2);
        }
        return isAccepted.get();
    }

    @Override
    public void init(NamedXContentRegistry xContentRegistry, Client client, SdkClient sdkClient, String tenantId) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        this.sdkClient = sdkClient;
        this.tenantId = tenantId;
        this.regexAcceptPattern = Pattern.compile(this.responseAccept);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.modelId != null) {
            builder.field(MODEL_ID_FIELD, this.modelId);
        }
        if (this.responseFilter != null) {
            builder.field(RESPONSE_FILTER_FIELD, this.responseFilter);
        }
        if (this.responseAccept != null) {
            builder.field(RESPONSE_VALIDATION_REGEX_FIELD, this.responseAccept);
        }
        builder.endObject();
        return builder;
    }

    public static ModelGuardrail parse(XContentParser parser) throws IOException {
        String modelId = null;
        String responseFilter = null;
        String responseAccept = null;
        XContentParserUtils.ensureExpectedToken((XContentParser.Token)XContentParser.Token.START_OBJECT, (XContentParser.Token)parser.currentToken(), (XContentParser)parser);
        block10: while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case "model_id": {
                    modelId = parser.text();
                    continue block10;
                }
                case "response_filter": {
                    responseFilter = parser.text();
                    continue block10;
                }
                case "response_validation_regex": {
                    responseAccept = parser.text();
                    continue block10;
                }
            }
            parser.skipChildren();
        }
        return ModelGuardrail.builder().modelId(modelId).responseFilter(responseFilter).responseAccept(responseAccept).build();
    }

    private <T extends ActionResponse> ActionListener<T> wrapActionListener(ActionListener<T> listener, Function<ActionResponse, T> recreate) {
        ActionListener actionListener = ActionListener.wrap(r -> listener.onResponse((Object)((ActionResponse)recreate.apply((ActionResponse)r))), e -> listener.onFailure(e));
        return actionListener;
    }

    @Generated
    public static ModelGuardrailBuilder builder() {
        return new ModelGuardrailBuilder();
    }

    @Generated
    public ModelGuardrailBuilder toBuilder() {
        return new ModelGuardrailBuilder().modelId(this.modelId).responseFilter(this.responseFilter).responseAccept(this.responseAccept);
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ModelGuardrail)) {
            return false;
        }
        ModelGuardrail other = (ModelGuardrail)o;
        if (!other.canEqual(this)) {
            return false;
        }
        String this$modelId = this.getModelId();
        String other$modelId = other.getModelId();
        if (this$modelId == null ? other$modelId != null : !this$modelId.equals(other$modelId)) {
            return false;
        }
        String this$responseFilter = this.getResponseFilter();
        String other$responseFilter = other.getResponseFilter();
        if (this$responseFilter == null ? other$responseFilter != null : !this$responseFilter.equals(other$responseFilter)) {
            return false;
        }
        String this$responseAccept = this.getResponseAccept();
        String other$responseAccept = other.getResponseAccept();
        if (this$responseAccept == null ? other$responseAccept != null : !this$responseAccept.equals(other$responseAccept)) {
            return false;
        }
        NamedXContentRegistry this$xContentRegistry = this.getXContentRegistry();
        NamedXContentRegistry other$xContentRegistry = other.getXContentRegistry();
        if (this$xContentRegistry == null ? other$xContentRegistry != null : !this$xContentRegistry.equals(other$xContentRegistry)) {
            return false;
        }
        Client this$client = this.getClient();
        Client other$client = other.getClient();
        if (this$client == null ? other$client != null : !this$client.equals(other$client)) {
            return false;
        }
        SdkClient this$sdkClient = this.getSdkClient();
        SdkClient other$sdkClient = other.getSdkClient();
        if (this$sdkClient == null ? other$sdkClient != null : !this$sdkClient.equals(other$sdkClient)) {
            return false;
        }
        String this$tenantId = this.getTenantId();
        String other$tenantId = other.getTenantId();
        if (this$tenantId == null ? other$tenantId != null : !this$tenantId.equals(other$tenantId)) {
            return false;
        }
        Pattern this$regexAcceptPattern = this.getRegexAcceptPattern();
        Pattern other$regexAcceptPattern = other.getRegexAcceptPattern();
        return !(this$regexAcceptPattern == null ? other$regexAcceptPattern != null : !this$regexAcceptPattern.equals(other$regexAcceptPattern));
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof ModelGuardrail;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result2 = 1;
        String $modelId = this.getModelId();
        result2 = result2 * 59 + ($modelId == null ? 43 : $modelId.hashCode());
        String $responseFilter = this.getResponseFilter();
        result2 = result2 * 59 + ($responseFilter == null ? 43 : $responseFilter.hashCode());
        String $responseAccept = this.getResponseAccept();
        result2 = result2 * 59 + ($responseAccept == null ? 43 : $responseAccept.hashCode());
        NamedXContentRegistry $xContentRegistry = this.getXContentRegistry();
        result2 = result2 * 59 + ($xContentRegistry == null ? 43 : $xContentRegistry.hashCode());
        Client $client = this.getClient();
        result2 = result2 * 59 + ($client == null ? 43 : $client.hashCode());
        SdkClient $sdkClient = this.getSdkClient();
        result2 = result2 * 59 + ($sdkClient == null ? 43 : $sdkClient.hashCode());
        String $tenantId = this.getTenantId();
        result2 = result2 * 59 + ($tenantId == null ? 43 : $tenantId.hashCode());
        Pattern $regexAcceptPattern = this.getRegexAcceptPattern();
        result2 = result2 * 59 + ($regexAcceptPattern == null ? 43 : $regexAcceptPattern.hashCode());
        return result2;
    }

    @Generated
    public String getModelId() {
        return this.modelId;
    }

    @Generated
    public String getResponseFilter() {
        return this.responseFilter;
    }

    @Generated
    public String getResponseAccept() {
        return this.responseAccept;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public SdkClient getSdkClient() {
        return this.sdkClient;
    }

    @Generated
    public String getTenantId() {
        return this.tenantId;
    }

    @Generated
    public Pattern getRegexAcceptPattern() {
        return this.regexAcceptPattern;
    }

    @Generated
    public static class ModelGuardrailBuilder {
        @Generated
        private String modelId;
        @Generated
        private String responseFilter;
        @Generated
        private String responseAccept;

        @Generated
        ModelGuardrailBuilder() {
        }

        @Generated
        public ModelGuardrailBuilder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        @Generated
        public ModelGuardrailBuilder responseFilter(String responseFilter) {
            this.responseFilter = responseFilter;
            return this;
        }

        @Generated
        public ModelGuardrailBuilder responseAccept(String responseAccept) {
            this.responseAccept = responseAccept;
            return this;
        }

        @Generated
        public ModelGuardrail build() {
            return new ModelGuardrail(this.modelId, this.responseFilter, this.responseAccept);
        }

        @Generated
        public String toString() {
            return "ModelGuardrail.ModelGuardrailBuilder(modelId=" + this.modelId + ", responseFilter=" + this.responseFilter + ", responseAccept=" + this.responseAccept + ")";
        }
    }
}

