Home Reference Source Test

packages/causality-layer/src/CausalNetRunner/runner.mixins.js

/**
 * This RunnerMixins class provides methods for the runner class: it stores the
 * net layers and net parameters and exposes Predictor / Encoder / Decoder
 * functions that push samples through the matching layer list.
 * @class RunnerMixins
 * @extends BaseRunnerClass
 */
const RunnerMixins = ( BaseRunnerClass )=> class extends BaseRunnerClass{
    /**
     * Store the parameter sets read by Predictor, Encoder and Decoder.
     * @param {Object} parameters - object whose PredictParameters,
     *     EncodeParameters and DecodeParameters properties are read by the
     *     getters below.
     */
    set NetParameters(parameters){
        this.netParameters = parameters;
    }

    /**
     * Store the layer lists read by Predictor, Encoder and Decoder.
     * @param {Object} netLayers - object whose Predict, Encode and Decode
     *     properties are passed to run() as the `layers` argument.
     */
    set NetLayers(netLayers){
        this.netLayers = netLayers;
    }

    /**
     * @returns {Object} the stored net parameters.
     * @throws {Error} when no (truthy) parameters have been set.
     */
    get NetParameters(){
        if(!this.netParameters){
            throw new Error('netParameters is not set');
        }
        return this.netParameters;
    }

    /**
     * @returns {Object} the stored net layers.
     * @throws {Error} when no (truthy) layers have been set.
     */
    get NetLayers(){
        if(!this.netLayers){
            throw new Error('netLayers is not set');
        }
        return this.netLayers;
    }

    /**
     * Run a single layer on a value.
     * @param {*} value - input handed to the layer function.
     * @param {Object} layerConfigure - layer description with `Name`, `Type`
     *     and, depending on `Type`, a `Flow` ('Tensor') or `Net` ('Layer') function.
     * @param {*} layerParameters - second argument for `Net`; unused for 'Tensor'.
     * @returns {Object} single-entry object mapping the layer `Name` to its result.
     * @throws {Error} when `Type` is neither 'Layer' nor 'Tensor'.
     */
    runLayer(value, layerConfigure, layerParameters){
        const {Name, Type, Flow, Net} = layerConfigure;
        if(Type === 'Tensor'){
            // Tensor layers only transform the value; they take no parameters.
            return {[Name]: Flow(value)};
        }
        if(Type === 'Layer'){
            return {[Name]: Net(value, layerParameters)};
        }
        throw new Error('type must be either Layer or Tensor');
    }

    /**
     * Run the layers in order, feeding each layer the previous layer's output.
     * @param {Iterable<Object>} layers - layer descriptions accepted by runLayer().
     * @param {*} samples - input for the first layer.
     * @param {Object} parameters - looked up by layer `Name` to obtain the
     *     parameters handed to each layer.
     * @returns {*} output of the last layer, or `samples` when `layers` is empty.
     */
    run(layers, samples, parameters){
        // Every intermediate output is kept under its layer name; 'PipeInput'
        // holds the original samples.
        const pipeValue = {PipeInput: samples};
        let lastLayer = 'PipeInput';
        for(const layer of layers){
            const layerOutput = this.runLayer(pipeValue[lastLayer], layer, parameters[layer.Name]);
            pipeValue[layer.Name] = layerOutput[layer.Name];
            lastLayer = layer.Name;
        }
        return pipeValue[lastLayer];
    }

    /**
     * @returns {Function} `(samples) => result` running the Predict layers.
     *     The layer list is captured when the getter is read; the
     *     PredictParameters are looked up again on every call.
     */
    get Predictor(){
        const predictLayers = this.NetLayers.Predict;
        const PredictParametersLenses = ()=>this.NetParameters.PredictParameters;
        return (samples)=>{
            const predictParameters = PredictParametersLenses();
            return this.run(predictLayers, samples, predictParameters);
        };
    }

    /**
     * @returns {Function} `(samples) => result` running the Encode layers.
     *     The layer list is captured when the getter is read; the
     *     EncodeParameters are looked up again on every call.
     */
    get Encoder(){
        const encodeLayers = this.NetLayers.Encode;
        const EncodeParametersLenses = ()=>this.NetParameters.EncodeParameters;
        return (samples)=>{
            const encodeParameters = EncodeParametersLenses();
            return this.run(encodeLayers, samples, encodeParameters);
        };
    }

    /**
     * @returns {Function} `(samples) => result` running the Decode layers.
     *     The layer list is captured when the getter is read; the parameters
     *     are looked up again on every call.
     */
    get Decoder(){
        const decodeLayers = this.NetLayers.Decode;
        // Decode layers use DecodeParameters. EncodeParameters is kept as a
        // fallback so parameter sets without a DecodeParameters entry still
        // resolve to the same value as before.
        const DecodeParametersLenses = ()=>{
            const {DecodeParameters, EncodeParameters} = this.NetParameters;
            if(DecodeParameters === undefined || DecodeParameters === null){
                return EncodeParameters;
            }
            return DecodeParameters;
        };
        return (samples)=>{
            const decodeParameters = DecodeParametersLenses();
            return this.run(decodeLayers, samples, decodeParameters);
        };
    }
};

export default RunnerMixins;