Home Reference Source Test

packages/causality-preprocessing/src/Tokenizer/sentencePiece.js

/**
 * This code originates from
 * https://github.com/tensorflow/tfjs-models/tree/master/universal-sentence-encoder
 * 
 * Tokenizer.encode() is a port of `EncodeAsIds` from the SentencePiece library
 * (https://github.com/google/sentencepiece). Encode uses the Viterbi algorithm
 * to find the most likely sequence of tokens that comprise the input. For more
 * details, refer to https://arxiv.org/pdf/1804.10959.pdf.
 */

import { default as stringToChars } from './utils';

import { default as Trie } from './trie';

// SentencePiece's whitespace meta symbol: U+2581 'LOWER ONE EIGHTH BLOCK'.
const separator = '\u2581';

/**
 * Prepares raw text for segmentation: applies NFKC normalization, replaces
 * every ASCII space with `separator`, and prefixes one `separator` so the
 * first word carries the same leading marker as words that follow a space.
 *
 * @param {string} str - Raw input text.
 * @returns {string} The processed text, or '' when the normalized input is
 *   empty (so empty input yields no tokens instead of a lone separator).
 */
function processInput(str) {
  const normalized = str.normalize('NFKC');
  if (normalized.length === 0) {
    return normalized;
  }
  return separator + normalized.replace(/ /g, separator);
}

// The first RESERVED_SYMBOLS_COUNT vocabulary entries are reserved for unk,
// control symbols, and user-defined symbols. The tokenizer's constructor
// starts inserting vocabulary entries into its trie at this index, so these
// reserved entries are never added to the trie.
const RESERVED_SYMBOLS_COUNT = 6;

/**
 * A list of `[piece, score]` pairs; a pair's position in the list is the id
 * passed to the trie for its piece.
 * @typedef {Array<[string, number]>} Vocabulary
 */

/**
 * A lattice candidate: `key` is the matched piece as an array of symbols,
 * `score` is its vocabulary score, and `index` is its vocabulary id.
 * @typedef {{key: string[], score: number, index: number}} Score
 */

/**
 * Tokenizer that segments text into vocabulary ids in the manner of
 * SentencePiece's `EncodeAsIds`. It builds a lattice of every vocabulary piece
 * found in the input and picks the highest-scoring path with the Viterbi
 * algorithm.
 */
export default class SentencePieceTokenizer {

  /**
   * @param {Array<[string, number]>} vocabulary - `[piece, score]` pairs; the
   *   position of a pair in the array is the id of its piece.
   * @param {number} [reservedSymbolsCount=RESERVED_SYMBOLS_COUNT] - number of
   *   leading entries (unk, control and user-defined symbols) that are not
   *   inserted into the trie.
   */
  constructor(vocabulary, reservedSymbolsCount = RESERVED_SYMBOLS_COUNT) {
    this.vocabulary = vocabulary;
    this.trie = new Trie();

    for (let id = reservedSymbolsCount; id < vocabulary.length; id++) {
      const [piece, score] = vocabulary[id];
      this.trie.insert(piece, score, id);
    }
  }

  /**
   * Encodes `input` as a sequence of vocabulary ids.
   *
   * @param {string} input - Raw text; it is normalized by `processInput` and
   *   split into symbols by `stringToChars` before segmentation.
   * @returns {number[]} Ids of the best-scoring segmentation, in input order,
   *   with each run of consecutive unknown tokens (id 0) collapsed to one.
   */
  encode(input) {
    const symbols = stringToChars(processInput(input));
    const size = symbols.length;

    // Build the lattice: candidates[end] holds every piece ending right before
    // symbol `end`, in order of increasing start position.
    const candidates = Array.from({ length: size + 1 }, () => []);
    for (let start = 0; start < size; start++) {
      for (const match of this.trie.commonPrefixSearch(symbols.slice(start))) {
        // match is [keySymbols, score, id]; the key length is counted in
        // symbols, the same unit used to index `symbols`.
        const length = match[0].length;
        candidates[start + length].push({ score: match[1], id: match[2], length });
      }
    }

    // Forward (Viterbi) pass: best[end] is the score of the best segmentation
    // of symbols[0..end), whose last piece is bestId[end] / bestLength[end].
    const best = new Array(size + 1).fill(0);
    const bestId = new Array(size + 1).fill(0);
    // A position that no piece ends at keeps id 0 (unk) and is treated as
    // spanning one symbol.
    const bestLength = new Array(size + 1).fill(1);

    for (let end = 1; end <= size; end++) {
      let found = false;
      for (const { score, id, length } of candidates[end]) {
        const total = score + best[end - length];
        // An explicit flag marks "no candidate yet"; using best === 0 as that
        // marker would collide with a genuine path score of 0.
        // `>=` lets a later candidate win a tie.
        if (!found || total >= best[end]) {
          best[end] = total;
          bestId[end] = id;
          bestLength[end] = length;
          found = true;
        }
      }
    }

    // Backward pass: step back by the symbol length of each chosen piece.
    // The vocabulary string's `.length` counts UTF-16 code units instead. That
    // mis-steps on astral characters, and on unk tokens whenever the id-0 entry
    // is not exactly one symbol long.
    const ids = [];
    for (let end = size; end > 0; end -= bestLength[end]) {
      ids.push(bestId[end]);
    }
    ids.reverse();

    // Merge consecutive unks.
    return ids.filter((id, k) => id !== 0 || ids[k - 1] !== 0);
  }
}