/*
 * Decompiled with CFR 0.152.
 */
package cl.mc3d.gpt4all;

import cl.mc3d.gpt4all.LLModelLibrary;
import cl.mc3d.gpt4all.Util;
import com.hexadevlabs.gpt4all.PromptIsTooLongException;
import java.io.ByteArrayOutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import jnr.ffi.Pointer;
import jnr.ffi.Runtime;
import jnr.ffi.byref.PointerByReference;

public class LLModel
implements AutoCloseable {
    public static String LIBRARY_SEARCH_PATH;
    public static boolean OUTPUT_DEBUG;
    protected static LLModelLibrary library;
    protected Pointer model;
    protected String modelName;
    private String gpt4allVersion;

    public static GenerationConfig.Builder config() {
        return new GenerationConfig.Builder();
    }

    LLModel() {
    }

    public LLModel(String gpt4allVersion, Path modelPath, String libraryPath, int contextSize) {
        this.gpt4allVersion = gpt4allVersion;
        if (library == null) {
            if (libraryPath != null) {
                library = Util.loadSharedLibrary(libraryPath);
                library.llmodel_set_implementation_search_path(libraryPath);
            } else {
                Path tempLibraryDirectory = Util.copySharedLibraries();
                library = Util.loadSharedLibrary(tempLibraryDirectory.toString());
                Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "llmodel_set_implementation_search_path: {0}", libraryPath);
                library.llmodel_set_implementation_search_path(tempLibraryDirectory.toString());
            }
        }
        this.modelName = modelPath.getFileName().toString();
        String modelPathAbs = modelPath.toAbsolutePath().toString();
        PointerByReference error = new PointerByReference();
        if (!Files.exists(modelPath, new LinkOption[0])) {
            throw new IllegalStateException("Model file does not exist: " + modelPathAbs);
        }
        if (!Files.isReadable(modelPath)) {
            throw new IllegalStateException("Model file cannot be read: " + modelPathAbs);
        }
        this.model = library.llmodel_model_create2(modelPathAbs, "auto", error);
        if (this.model == null) {
            throw new IllegalStateException("Could not load, gpt4all backend returned error: " + ((Pointer)error.getValue()).getString(0L));
        }
        library.llmodel_loadModel(this.model, modelPathAbs, contextSize, 100);
        if (!library.llmodel_isModelLoaded(this.model)) {
            throw new IllegalStateException("The model " + this.modelName + " could not be loaded");
        }
    }

    public void setThreadCount(int nThreads) {
        library.llmodel_setThreadCount(this.model, nThreads);
    }

    public int threadCount() {
        return library.llmodel_threadCount(this.model);
    }

    public String generate(String prompt, GenerationConfig generationConfig, String prompt_template, boolean special, String fake_reply) throws UnsupportedEncodingException {
        return this.generate(prompt, generationConfig, prompt_template, false, special, fake_reply);
    }

    public String generate(String prompt, GenerationConfig generationConfig, String prompt_template, boolean streamToStdOut, boolean special, String sfake_reply) throws UnsupportedEncodingException {
        PointerByReference fakeReply = new PointerByReference();
        ByteArrayOutputStream bufferingForStdOutStream = new ByteArrayOutputStream();
        ByteArrayOutputStream bufferingForWholeGeneration = new ByteArrayOutputStream();
        LLModelLibrary.ResponseCallback responseCallback = LLModel.getResponseCallback(streamToStdOut, bufferingForStdOutStream, bufferingForWholeGeneration);
        if (this.gpt4allVersion.contains("-")) {
            this.gpt4allVersion = this.gpt4allVersion.substring(this.gpt4allVersion.indexOf("-") + 1, this.gpt4allVersion.length());
        }
        if (this.gpt4allVersion.equals("2.7.0")) {
            prompt = "### user: " + prompt + "### assistant:";
            library.llmodel_prompt(this.model, prompt, tokenID -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "token " + tokenID);
                }
                return true;
            }, responseCallback, isRecalculating -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "recalculating");
                }
                return isRecalculating;
            }, generationConfig);
        } else if (this.gpt4allVersion.equals("2.7.1")) {
            library.llmodel_prompt(this.model, prompt, prompt_template, tokenID -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "token " + tokenID);
                }
                return true;
            }, responseCallback, isRecalculating -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "recalculating");
                }
                return isRecalculating;
            }, generationConfig, special);
        } else {
            prompt_template = "### Human: %1 ### ### Assistant:";
            library.llmodel_prompt(this.model, prompt, prompt_template, tokenID -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "token " + tokenID);
                }
                return true;
            }, responseCallback, isRecalculating -> {
                if (OUTPUT_DEBUG) {
                    Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "recalculating");
                }
                return isRecalculating;
            }, generationConfig, special, sfake_reply);
        }
        return bufferingForWholeGeneration.toString("UTF-8");
    }

    public boolean llmodel_has_gpu_device() {
        return library.llmodel_has_gpu_device(this.model);
    }

    public boolean llmodel_isModelLoaded() {
        return library.llmodel_isModelLoaded(this.model);
    }

    public boolean llmodel_gpu_init_gpu_device_by_int(int device) {
        return library.llmodel_gpu_init_gpu_device_by_int(this.model, device);
    }

    static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOut, ByteArrayOutputStream bufferingForStdOutStream, ByteArrayOutputStream bufferingForWholeGeneration) {
        return (tokenID, response) -> {
            byte nextByte;
            if (OUTPUT_DEBUG) {
                Logger.getLogger(LLModel.class.getName()).log(Level.INFO, "Response token " + tokenID + " ");
            }
            if (tokenID == -1) {
                throw new PromptIsTooLongException(response.getString(0L, 1000, StandardCharsets.UTF_8));
            }
            long len = 0L;
            do {
                try {
                    nextByte = response.getByte(len);
                }
                catch (IndexOutOfBoundsException e) {
                    throw new RuntimeException("Empty array or not null terminated");
                }
                ++len;
                if (nextByte == 0) continue;
                bufferingForWholeGeneration.write(nextByte);
                if (!streamToStdOut) continue;
                bufferingForStdOutStream.write(nextByte);
                byte[] currentBytes = bufferingForStdOutStream.toByteArray();
                String validString = Util.getValidUtf8(currentBytes);
                if (validString == null) continue;
                System.out.print(validString);
                bufferingForStdOutStream.reset();
            } while (nextByte != 0);
            return true;
        };
    }

    public CompletionReturn chatCompletionResponse(Messages messages, GenerationConfig generationConfig) throws UnsupportedEncodingException {
        return this.chatCompletion(messages, generationConfig, false, false);
    }

    public CompletionReturn chatCompletion(Messages messages, GenerationConfig generationConfig, boolean streamToStdOut, boolean outputFullPromptToStdOut) throws UnsupportedEncodingException {
        String fullPrompt = LLModel.buildPrompt(messages.toListMap());
        if (outputFullPromptToStdOut) {
            Logger.getLogger(LLModel.class.getName()).log(Level.INFO, fullPrompt);
        }
        String generatedText = this.generate(fullPrompt, generationConfig, "", streamToStdOut, true, null);
        CompletionChoice promptMessage = new CompletionChoice(Role.ASSISTANT, generatedText);
        Choices choices = new Choices(promptMessage);
        Usage usage = this.getUsage(fullPrompt, generatedText);
        return new CompletionReturn(this.modelName, usage, choices);
    }

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages, GenerationConfig generationConfig) throws UnsupportedEncodingException {
        return this.chatCompletion(messages, generationConfig, false, false);
    }

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages, GenerationConfig generationConfig, boolean streamToStdOut, boolean outputFullPromptToStdOut) throws UnsupportedEncodingException {
        String fullPrompt = LLModel.buildPrompt(messages);
        if (outputFullPromptToStdOut) {
            Logger.getLogger(LLModel.class.getName()).log(Level.INFO, fullPrompt);
        }
        String generatedText = this.generate(fullPrompt, generationConfig, "", streamToStdOut, true, null);
        ChatCompletionResponse response = new ChatCompletionResponse();
        response.model = this.modelName;
        response.usage = this.getUsage(fullPrompt, generatedText);
        HashMap<String, String> message = new HashMap<String, String>();
        message.put("role", "assistant");
        message.put("content", generatedText);
        ArrayList<Map<String, String>> lChoices = new ArrayList<Map<String, String>>();
        lChoices.add(message);
        response.choices = lChoices;
        return response;
    }

    private Usage getUsage(String fullPrompt, String generatedText) {
        Usage usage = new Usage();
        usage.promptTokens = fullPrompt.length();
        usage.completionTokens = generatedText.length();
        usage.totalTokens = fullPrompt.length() + generatedText.length();
        return usage;
    }

    protected static String buildPrompt(List<Map<String, String>> messages) {
        StringBuilder fullPrompt = new StringBuilder();
        for (Map<String, String> message : messages) {
            if (!"system".equals(message.get("role"))) continue;
            String systemMessage = message.get("content") + "\n";
            fullPrompt.append(systemMessage);
        }
        fullPrompt.append("### Instruction: \nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.\n### Prompt: ");
        for (Map<String, String> message : messages) {
            if ("user".equals(message.get("role"))) {
                String userMessage = "\n" + message.get("content");
                fullPrompt.append(userMessage);
            }
            if (!"assistant".equals(message.get("role"))) continue;
            String assistantMessage = "\n### Response: " + message.get("content");
            fullPrompt.append(assistantMessage);
        }
        fullPrompt.append("\n### Response:");
        return fullPrompt.toString();
    }

    @Override
    public void close() throws Exception {
        library.llmodel_model_destroy(this.model);
    }

    static {
        OUTPUT_DEBUG = false;
    }

    public static class Usage {
        public int promptTokens;
        public int completionTokens;
        public int totalTokens;
    }

    public static class ChatCompletionResponse {
        public String model;
        public Usage usage;
        public List<Map<String, String>> choices;
    }

    public static class CompletionChoice
    extends PromptMessage {
        public CompletionChoice(Role role, String content) {
            super(role, content);
        }
    }

    public static class Choices {
        private final List<CompletionChoice> choices = new ArrayList<CompletionChoice>();

        public Choices(List<CompletionChoice> choices) {
            this.choices.addAll(choices);
        }

        public Choices(CompletionChoice ... completionChoices) {
            this.choices.addAll(Arrays.asList(completionChoices));
        }

        public Choices addCompletionChoice(CompletionChoice completionChoice) {
            this.choices.add(completionChoice);
            return this;
        }

        public CompletionChoice first() {
            return this.choices.get(0);
        }

        public int totalChoices() {
            return this.choices.size();
        }

        public CompletionChoice get(int index) {
            return this.choices.get(index);
        }

        public List<CompletionChoice> choices() {
            return Collections.unmodifiableList(this.choices);
        }
    }

    public static class CompletionReturn {
        private String model;
        private Usage usage;
        private Choices choices;

        public CompletionReturn(String model, Usage usage, Choices choices) {
            this.model = model;
            this.usage = usage;
            this.choices = choices;
        }

        public Choices choices() {
            return this.choices;
        }

        public String model() {
            return this.model;
        }

        public Usage usage() {
            return this.usage;
        }
    }

    public static enum Role {
        SYSTEM("system"),
        ASSISTANT("assistant"),
        USER("user");

        private final String type;

        String type() {
            return this.type;
        }

        static Role from(String type) {
            if (type == null) {
                return null;
            }
            switch (type) {
                case "system": {
                    return SYSTEM;
                }
                case "assistant": {
                    return ASSISTANT;
                }
                case "user": {
                    return USER;
                }
            }
            throw new IllegalArgumentException(String.format("You passed %s type but only %s are supported", type, Arrays.toString((Object[])Role.values())));
        }

        private Role(String type) {
            this.type = type;
        }

        public String toString() {
            return this.type();
        }
    }

    public static class PromptMessage {
        private static final String ROLE = "role";
        private static final String CONTENT = "content";
        private final Map<String, String> message = new HashMap<String, String>();

        public PromptMessage() {
        }

        public PromptMessage(Role role, String content) {
            this.addRole(role);
            this.addContent(content);
        }

        public PromptMessage addRole(Role role) {
            return this.addParameter(ROLE, role.type());
        }

        public PromptMessage addContent(String content) {
            return this.addParameter(CONTENT, content);
        }

        public PromptMessage addParameter(String key, String value) {
            this.message.put(key, value);
            return this;
        }

        public String content() {
            return this.parameter(CONTENT);
        }

        public Role role() {
            String role = this.parameter(ROLE);
            return Role.from(role);
        }

        public String parameter(String key) {
            return this.message.get(key);
        }

        Map<String, String> toMap() {
            return Collections.unmodifiableMap(this.message);
        }
    }

    public static class Messages {
        private final List<PromptMessage> messages = new ArrayList<PromptMessage>();

        public Messages(PromptMessage ... messages) {
            this.messages.addAll(Arrays.asList(messages));
        }

        public Messages(List<PromptMessage> messages) {
            this.messages.addAll(messages);
        }

        public Messages addPromptMessage(PromptMessage promptMessage) {
            this.messages.add(promptMessage);
            return this;
        }

        List<PromptMessage> toList() {
            return Collections.unmodifiableList(this.messages);
        }

        List<Map<String, String>> toListMap() {
            return this.messages.stream().map(PromptMessage::toMap).collect(Collectors.toList());
        }
    }

    public static class GenerationConfig
    extends LLModelLibrary.LLModelPromptContext {
        private GenerationConfig() {
            super(Runtime.getSystemRuntime());
            this.logits_size.set(0L);
            this.tokens_size.set(0L);
            this.n_past.set(0L);
            this.n_ctx.set(4096L);
            this.n_predict.set(128L);
            this.top_k.set(40L);
            this.top_p.set((Number)0.95);
            this.temp.set((Number)0.28);
            this.n_batch.set(8L);
            this.repeat_penalty.set((Number)1.1);
            this.repeat_last_n.set(10L);
            this.context_erase.set((Number)0.55);
        }

        public static class Builder {
            private final GenerationConfig configToBuild = new GenerationConfig();

            public Builder withNPast(int n_past) {
                this.configToBuild.n_past.set((long)n_past);
                return this;
            }

            public Builder withNCtx(int n_ctx) {
                this.configToBuild.n_ctx.set((long)n_ctx);
                return this;
            }

            public Builder withNPredict(int n_predict) {
                this.configToBuild.n_predict.set((long)n_predict);
                return this;
            }

            public Builder withTopK(int top_k) {
                this.configToBuild.top_k.set((long)top_k);
                return this;
            }

            public Builder withTopP(float top_p) {
                this.configToBuild.top_p.set(top_p);
                return this;
            }

            public Builder withTemp(float temp) {
                this.configToBuild.temp.set(temp);
                return this;
            }

            public Builder withNBatch(int n_batch) {
                this.configToBuild.n_batch.set((long)n_batch);
                return this;
            }

            public Builder withRepeatPenalty(float repeat_penalty) {
                this.configToBuild.repeat_penalty.set(repeat_penalty);
                return this;
            }

            public Builder withRepeatLastN(int repeat_last_n) {
                this.configToBuild.repeat_last_n.set((long)repeat_last_n);
                return this;
            }

            public Builder withContextErase(float context_erase) {
                this.configToBuild.context_erase.set(context_erase);
                return this;
            }

            public GenerationConfig build() {
                return this.configToBuild;
            }
        }
    }
}

