import { Injectable } from '@angular/core';

import { HtmlEntitiesnService } from '../utils/html-entities.service';

@Injectable({
    providedIn: 'root'
})
export class AiClipTokenizerService {

    /** Joins the two halves of a merge into a `bpeRanks` key (JS has no tuple-valued object keys). */
    private static readonly MERGE_SEPARATOR = '·😎·';

    /** Vocabulary id of `<|startoftext|>` (512 byte symbols + 48894 merges precede it when the full merge list is loaded). */
    private static readonly START_TOKEN_ID = 49406;

    /** Vocabulary id of `<|endoftext|>`; `encodeForCLIP` also pads with it. */
    private static readonly END_TOKEN_ID = 49407;

    /** Fixed sequence length produced by `encodeForCLIP`. */
    private static readonly CONTEXT_LENGTH = 77;

    /** Inverse of `byteEncoder`: stand-in character -> byte value (stored as a decimal string). */
    byteDecoder: Record<string, string>;

    /** Byte value (0-255) -> printable character that represents that byte inside BPE tokens. */
    byteEncoder: Record<number, string>;

    /** Merge rank keyed by `first + MERGE_SEPARATOR + second`; lower ranks merge first. Populated by `init()`. */
    bpeRanks: any;

    /**
     * Memoised `bpe()` results, keyed by byte-level token.
     * Prototype-less so words such as "constructor" are never mistaken for cached entries.
     */
    cache: Record<string, string>;

    /** Token id -> token string. Populated by `init()`. */
    decoder: any;

    /** Token string -> token id. Populated by `init()`. */
    encoder: any;

    /** Pre-tokenisation pattern that splits cleaned text into word-level chunks. */
    pat: RegExp;

    private readonly utf8Encoder = new TextEncoder();

    // ignoreBOM: true keeps a leading U+FEFF in the output instead of silently stripping it.
    private readonly utf8Decoder = new TextDecoder('utf-8', { ignoreBOM: true });

    constructor(
        private htmlEntities: HtmlEntitiesnService,
    ) {
        this.byteEncoder = this.bytesToUnicode();
        this.byteDecoder = Object.fromEntries(Object.entries(this.byteEncoder).map(([k, v]) => [v, k]));

        // Special tokens are pre-seeded so bpe() returns them unsplit.
        this.cache = Object.assign(Object.create(null) as Record<string, string>, {
            '<|startoftext|>': '<|startoftext|>',
            '<|endoftext|>': '<|endoftext|>',
        });
        this.pat = /<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/gui;
    }

    /**
     * Runs the text through the injected HTML-entity decoder twice (presumably so double-escaped
     * input such as `&amp;amp;` is fully unescaped) and trims the result.
     */
    basicClean(text: string): string {
        text = this.htmlEntities.decode(this.htmlEntities.decode(text));
        return text.trim();
    }

    /**
     * Applies the learned byte-pair merges to one byte-level-encoded word.
     *
     * @param token A single pre-tokenised chunk already mapped through `byteEncoder`.
     * @returns The resulting sub-tokens joined by single spaces. The last sub-token carries the
     *          `</w>` end-of-word marker, except for the pre-seeded special tokens.
     */
    bpe(token: string): string {
        const cached = Object.prototype.hasOwnProperty.call(this.cache, token) ? this.cache[token] : undefined;
        if (cached !== undefined) {
            return cached;
        }

        let word: string[] = [...token.slice(0, -1), token.slice(-1) + '</w>'];
        let pairs = this.getPairs(word);

        if (pairs.length === 0) {
            const single = token + '</w>';
            this.cache[token] = single;
            return single;
        }

        while (true) {
            // Pick the adjacent pair with the lowest merge rank; stop when no pair is mergeable.
            let bigram: [string, string] | null = null;
            let minRank = Infinity;

            for (const pair of pairs) {
                const rank: number | undefined = this.bpeRanks[pair.join(AiClipTokenizerService.MERGE_SEPARATOR)];
                if (rank !== undefined && rank < minRank) {
                    minRank = rank;
                    bigram = pair;
                }
            }

            if (bigram === null) {
                break;
            }

            const [first, second] = bigram;
            const merged: string[] = [];
            let i = 0;

            // Replace every non-overlapping occurrence of (first, second), scanning left to right.
            while (i < word.length) {
                const j = word.indexOf(first, i);

                if (j === -1) {
                    merged.push(...word.slice(i));
                    break;
                }

                merged.push(...word.slice(i, j));
                i = j;

                if (i < word.length - 1 && word[i + 1] === second) {
                    merged.push(first + second);
                    i += 2;
                } else {
                    merged.push(first);
                    i += 1;
                }
            }

            word = merged;

            if (word.length === 1) {
                break;
            }
            pairs = this.getPairs(word);
        }

        const result = word.join(' ');
        this.cache[token] = result;
        return result;
    }

    /**
     * Builds the reversible byte -> character table used by byte-level BPE.
     * Printable bytes map to themselves; every other byte maps to U+0100 onwards, in ascending
     * byte order, so that no token ever contains whitespace or control characters.
     */
    bytesToUnicode(): Record<number, string> {
        const printable = [
            ...this.range(this.ord('!'), this.ord('~') + 1),
            ...this.range(this.ord('¡'), this.ord('¬') + 1),
            ...this.range(this.ord('®'), this.ord('ÿ') + 1),
        ];

        const table: Record<number, string> = {};
        for (const byte of printable) {
            table[byte] = String.fromCharCode(byte);
        }

        let shifted = 0;
        for (let byte = 0; byte < 2 ** 8; byte++) {
            if (!(byte in table)) {
                table[byte] = String.fromCharCode(2 ** 8 + shifted);
                shifted += 1;
            }
        }

        return table;
    }

    /**
     * Converts token ids back into text.
     * The byte-level symbols are turned back into UTF-8 bytes and decoded (invalid sequences
     * become U+FFFD), and every `</w>` end-of-word marker becomes a space.
     * Unknown ids contribute nothing.
     */
    decode(tokens: ArrayLike<unknown>): string {
        const text: string = Array.from(tokens, (token) => this.decoder[String(token)]).join('');
        const bytes = new Uint8Array(Array.from(text, (symbol) => Number(this.byteDecoder[symbol])));
        return this.utf8Decoder.decode(bytes).replace(/<\/w>/g, ' ');
    }

    /**
     * Tokenises text into CLIP vocabulary ids (no start/end tokens, no padding).
     *
     * Pipeline: HTML-entity decode + trim, collapse whitespace, lowercase, split with `pat`,
     * map each chunk's UTF-8 bytes to byte-level symbols, then apply BPE merges.
     *
     * @throws Error if there is something to encode but `init()` has not loaded the vocabulary.
     */
    encode(text: string | null | undefined): number[] {
        const cleaned = this.whitespaceClean(this.basicClean(`${text || ''}`)).toLowerCase();
        const chunks = Array.from(cleaned.matchAll(this.pat), (match) => match[0]);

        if (chunks.length > 0 && !this.encoder) {
            throw new Error('AiClipTokenizerService.encode() called before init() finished loading the vocabulary');
        }

        const ids: number[] = [];
        for (const chunk of chunks) {
            // BPE operates on UTF-8 bytes, not UTF-16 code units.
            const byteLevel = Array.from(this.utf8Encoder.encode(chunk), (byte) => this.byteEncoder[byte]).join('');
            for (const piece of this.bpe(byteLevel).split(' ')) {
                ids.push(this.encoder[piece]);
            }
        }

        return ids;
    }

    /**
     * Produces CLIP's fixed-length input. The layout is: the start id, then the text ids truncated so
     * the end id always fits, then the end id, then padding to 77 entries.
     * NOTE: padding uses the end-of-text id (49407), not 0. Keep it in sync with what the
     * downstream model expects.
     */
    encodeForCLIP(text: string | null | undefined): number[] {
        const { START_TOKEN_ID, END_TOKEN_ID, CONTEXT_LENGTH } = AiClipTokenizerService;

        const body = this.encode(text).slice(0, CONTEXT_LENGTH - 2);
        const tokens = [START_TOKEN_ID, ...body, END_TOKEN_ID];

        while (tokens.length < CONTEXT_LENGTH) {
            tokens.push(END_TOKEN_ID);
        }

        return tokens;
    }

    /** Returns every adjacent pair of symbols in `word`, in order (empty for 0 or 1 symbols). */
    getPairs(word: readonly string[]): [string, string][] {
        const pairs: [string, string][] = [];

        for (let i = 1; i < word.length; i++) {
            pairs.push([word[i - 1], word[i]]);
        }

        return pairs;
    }

    /**
     * Loads the BPE merge list from `./assets/lib/ai/bpe_simple_vocab.json`. The payload is expected
     * to be `{ text: string }`, with one space-separated merge per line; the first line is skipped,
     * presumably a version header. From it this method builds `encoder`, `decoder` and `bpeRanks`.
     * It returns immediately once the tables exist.
     *
     * @returns Resolves to `true` when the tables are ready.
     * @throws (rejects) if the HTTP status is >= 400, the payload has no non-empty `text` string,
     *         or fetching / JSON parsing fails.
     */
    async init(): Promise<boolean> {
        if (!!this.encoder && !!this.decoder && !!this.bpeRanks) {
            return true;
        }

        const response = await fetch('./assets/lib/ai/bpe_simple_vocab.json');

        // Status 0 (e.g. file:// in some WebViews) is deliberately tolerated.
        if (response.status >= 400) {
            throw new Error(`Failed to load CLIP BPE vocabulary: HTTP ${response.status}`);
        }

        const bpeVocabData: unknown = await response.json();
        const mergesText = (bpeVocabData as { text?: unknown } | null)?.text;

        if (typeof mergesText !== 'string' || mergesText.length === 0) {
            throw new Error('CLIP BPE vocabulary payload is missing its "text" field');
        }

        const merges: string[][] = mergesText
            .split('\n')
            .slice(1, 49152 - 256 - 2 + 1)
            .map((line) => line.split(' '));

        const alphabet = this.byteLevelAlphabet();
        const vocab: string[] = [
            ...alphabet,
            ...alphabet.map((symbol) => symbol + '</w>'),
            ...merges.map((merge) => merge.join('')),
            '<|startoftext|>',
            '<|endoftext|>',
        ];

        this.encoder = Object.fromEntries(vocab.map((v, i) => [v, i]));
        this.decoder = Object.fromEntries(Object.entries(this.encoder).map(([k, v]) => [v, k]));
        this.bpeRanks = Object.fromEntries(
            merges.map((merge, i) => [merge.join(AiClipTokenizerService.MERGE_SEPARATOR), i]),
        );

        return true;
    }

    /** Returns the character code of the first UTF-16 unit of `c`. */
    ord(c: string): number {
        return c.charCodeAt(0);
    }

    /**
     * Python-style `range`: `range(stop)` or `range(start, stop, step)`, end-exclusive.
     * @throws RangeError if `step` is 0 (which would otherwise loop forever).
     */
    range(start: number, stop?: number, step: number = 1): number[] {
        if (step === 0) {
            throw new RangeError('range() step must not be zero');
        }

        if (stop === undefined) {
            stop = start;
            start = 0;
        }

        const result: number[] = [];

        for (let i = start; step > 0 ? i < stop : i > stop; i += step) {
            result.push(i);
        }

        return result;
    }

    /** Collapses every whitespace run to a single space and trims; falsy input yields ''. */
    whitespaceClean(text: string): string {
        return `${text || ''}`.replace(/\s+/g, ' ').trim();
    }

    /**
     * The 256 byte stand-in characters ordered by code point. Printable bytes map to themselves and
     * the rest map to U+0100 upwards, so this is exactly the base-vocabulary order
     * ('!' = id 0 … 'Ń' = id 255).
     */
    private byteLevelAlphabet(): string[] {
        return Object.values(this.byteEncoder).sort((a, b) => a.charCodeAt(0) - b.charCodeAt(0));
    }

}