// packages/causality-layer/src/CausalNetRunner/runner.mixins.js
/**
 * Mixin factory that adds network-running behaviour to a base class.
 *
 * The returned subclass stores a layer configuration (`NetLayers`) and a
 * parameter set (`NetParameters`), and exposes `Predictor`, `Encoder` and
 * `Decoder` getters. Each getter returns a function that pipes samples through
 * the corresponding layer list.
 *
 * @param {Function} BaseRunnerClass - Class to extend.
 * @returns {Function} A subclass of `BaseRunnerClass` with the runner members mixed in.
 */
const RunnerMixins = (BaseRunnerClass) => class extends BaseRunnerClass {
  /**
   * Stores the parameter set used by Predictor/Encoder/Decoder.
   * @param {Object} parameters - Read as `PredictParameters`, `EncodeParameters`
   *   and `DecodeParameters`, each an object keyed by layer `Name`.
   */
  set NetParameters(parameters) {
    this.netParameters = parameters;
  }

  /**
   * Stores the layer configuration.
   * @param {Object} netLayers - Read as `Predict`, `Encode` and `Decode` layer lists.
   */
  set NetLayers(netLayers) {
    this.netLayers = netLayers;
  }

  /**
   * @returns {Object} The stored parameter set.
   * @throws {Error} If no (truthy) parameter set has been assigned.
   */
  get NetParameters() {
    if (!this.netParameters) {
      throw new Error('netParameters is not set');
    }
    return this.netParameters;
  }

  /**
   * @returns {Object} The stored layer configuration.
   * @throws {Error} If no (truthy) layer configuration has been assigned.
   */
  get NetLayers() {
    if (!this.netLayers) {
      throw new Error('netLayers is not set');
    }
    return this.netLayers;
  }

  /**
   * Runs a single layer.
   *
   * A `'Tensor'` layer calls `Flow(value)`. A `'Layer'` layer calls
   * `Net(value, layerParameters)`.
   *
   * @param {*} value - Input to the layer.
   * @param {{Name: string, Type: string, Flow?: Function, Net?: Function}} layerConfigure
   * @param {*} layerParameters - Passed to `Net` for `'Layer'` layers; ignored otherwise.
   * @returns {Object} `{[Name]: result}`.
   * @throws {Error} If `Type` is neither `'Layer'` nor `'Tensor'`.
   */
  runLayer(value, layerConfigure, layerParameters) {
    const { Name, Type, Flow, Net } = layerConfigure;
    if (Type === 'Tensor') {
      return { [Name]: Flow(value) };
    }
    if (Type === 'Layer') {
      return { [Name]: Net(value, layerParameters) };
    }
    throw new Error('type must be either Layer or Tensor');
  }

  /**
   * Pipes `samples` through `layers` in order, feeding each layer's output to the next.
   *
   * @param {Array<Object>} layers - Layer configurations (see `runLayer`).
   * @param {*} samples - Input to the first layer.
   * @param {Object} [parameters] - Per-layer parameters keyed by layer `Name`.
   *   May be omitted for pipelines that need no parameters.
   * @returns {*} Output of the last layer, or `samples` when `layers` is empty.
   */
  run(layers, samples, parameters) {
    let current = samples;
    for (const layer of layers) {
      // Tolerate a missing parameter set: Tensor layers never read it.
      const layerParameters = parameters == null ? undefined : parameters[layer.Name];
      const layerOutput = this.runLayer(current, layer, layerParameters);
      current = layerOutput[layer.Name];
    }
    return current;
  }

  /**
   * Returns a function that runs `NetLayers.Predict`. The layer list is captured
   * now. `PredictParameters` is re-read on every call, so later parameter
   * updates are picked up.
   * @returns {(samples: *) => *}
   */
  get Predictor() {
    const predictLayers = this.NetLayers.Predict;
    const predictParametersLens = () => this.NetParameters.PredictParameters;
    return (samples) => this.run(predictLayers, samples, predictParametersLens());
  }

  /**
   * Returns a function that runs `NetLayers.Encode` with `EncodeParameters`.
   * The parameters are re-read on every call.
   * @returns {(samples: *) => *}
   */
  get Encoder() {
    const encodeLayers = this.NetLayers.Encode;
    const encodeParametersLens = () => this.NetParameters.EncodeParameters;
    return (samples) => this.run(encodeLayers, samples, encodeParametersLens());
  }

  /**
   * Returns a function that runs `NetLayers.Decode` with `DecodeParameters`.
   * The parameters are re-read on every call.
   * @returns {(samples: *) => *}
   */
  get Decoder() {
    const decodeLayers = this.NetLayers.Decode;
    const decodeParametersLens = () => {
      const { DecodeParameters, EncodeParameters } = this.NetParameters;
      // Fall back to EncodeParameters when no decoder parameters are supplied,
      // so parameter sets that only define encoder weights keep working.
      return DecodeParameters === undefined ? EncodeParameters : DecodeParameters;
    };
    return (samples) => this.run(decodeLayers, samples, decodeParametersLens());
  }
};
export default RunnerMixins;