Home Reference Source Test

packages/causality-representation/src/universalEmbedding.js

import { platform, fetch } from 'causal-net.utils';
import { Tensor } from 'causal-net.core';
import { default as functor } from './functor';
import { default as VectorMetricMixins } from './vectorMetrics.mixins';

/**
 * This UniversalEmbedding provide methods for transform sentences into 
 * its single vector representation 
 * based on [use](https://github.com/tensorflow/tfjs-models/tree/master/universal-sentence-encoder)
 * 
 * @class UniversalEmbedding
 * @extends {Tensor}
 * @example
 * [EXAMPLE ../examples/universalEmbedding.babel.js]
 */
class UniversalEmbedding extends platform.mixWith(Tensor, 
    [ VectorMetricMixins ]){
    
    constructor(functor){
        super();
        /**
         * @private { TensorModel } 
         */
        this.model = null;
        this.f = functor;
        this.R = this.f.CoreFunctor;
        this.vecSize = 512;
    }
    
    async connect(link){
        this.model = await this.T.loadGraphModel(link);
        return this;
    }

    async sentenceEncode(sentences, asTensor=true){
        if(!this.model){
            throw Error(`model is not connect`);
        }
        const T = this.T;
        //only one sentence
        let embeddings = [];
        for(let tokenIdxs of sentences){
            let idxs = tokenIdxs.map((t,i)=>[0,i]);
            let indices = T.tensor(idxs, [tokenIdxs.length, 2], 'int32');
            let values = T.tensor(tokenIdxs, [tokenIdxs.length], 'int32');
            embeddings.push( await this.model.executeAsync({indices, values}) );
        }
        let sentTensor = T.stack(embeddings);
        if(asTensor){
            return sentTensor;
        }
        else{
            let vec = await sentTensor.data();
            sentTensor.dispose();
            return this.R.splitEvery(this.vecSize, vec);
        }
    }

    async transform(tokenIdxs){
        let [vec] = await this.sentenceEncode([tokenIdxs], false);
        return Array.from(vec);
    }
}
export default new UniversalEmbedding(functor);