packages/causality-optimizers/src/evaluator.mixins.js
/**
 * Mixin that adds evaluation support to a pipeline class: the **test** method
 * and the **TestDataGenerator** accessor, which `setByConfig` fills from
 * `pipelineConfig.Dataset.TestDataGenerator`.
 *
 * @class EvaluatorMixins
 * @extends {BasePipelineClass}
 * @example
 * [EXAMPLE ../examples/trainer.mixins.babel.js]
 */
const EvaluatorMixins = (BasePipelineClass) => class extends BasePipelineClass {
  /**
   * Generator factory used by `test` to produce evaluation batches.
   *
   * @throws {Error} 'testDataGenerator is not set' when no generator was assigned.
   */
  get TestDataGenerator() {
    if (!this.testDataGenerator) {
      throw new Error('testDataGenerator is not set');
    }
    return this.testDataGenerator;
  }

  set TestDataGenerator(testDataGenerator) {
    this.testDataGenerator = testDataGenerator;
  }

  /**
   * Runs the model over every batch yielded by `TestDataGenerator(batchSize)`.
   *
   * For each batch, predictions from `OneHotPredictModel` are multiplied
   * element-wise with the labels and summed along axis 1, which gives one
   * score per sample. Assuming one-hot predictions and labels (TODO confirm
   * against callers), each score is 1 for a correct prediction and 0 otherwise.
   *
   * @param {string} [metric='accuracy'] - Currently unused; only accuracy is computed.
   * @param {number} [batchSize=32] - Forwarded to the test-data generator.
   * @returns {Promise<{accuracy: *, pass: number[]}>} `pass` holds the per-sample
   *   scores; `accuracy` is `this.F.CoreFunctor.mean(pass)`.
   * @throws Rejects if no generator is set, or if the generator or model throws.
   */
  async test(metric = 'accuracy', batchSize = 32) {
    const T = this.T;
    const R = this.F.CoreFunctor;
    const TestDataGenerator = this.TestDataGenerator;
    const OneHotPredictModel = this.OneHotPredictModel;
    const testData = TestDataGenerator(batchSize);
    const pass = [];
    for await (const { samples, labels } of testData) {
      // Every tensor allocated for this batch, including intermediates, is
      // released in `finally`. Assumes T is a TensorFlow.js-like API whose
      // tensors expose dispose() (NOTE(review): confirm).
      const batchTensors = [];
      const track = (tensor) => {
        batchTensors.push(tensor);
        return tensor;
      };
      try {
        const rawSamples = track(T.tensor(samples));
        const sampleTensor = track(rawSamples.asType('float32'));
        const rawLabels = track(T.tensor(labels));
        const labelTensor = track(rawLabels.asType('float32'));
        const predictLabelTensor = track(OneHotPredictModel(sampleTensor));
        const correctPredicts = track(predictLabelTensor.mul(labelTensor));
        const perSample = await track(correctPredicts.sum(1)).data();
        // Append in place; re-spreading `pass` each batch would be quadratic.
        for (const value of perSample) {
          pass.push(value);
        }
      } finally {
        for (const tensor of batchTensors) {
          if (tensor && typeof tensor.dispose === 'function') {
            tensor.dispose();
          }
        }
      }
    }
    const accuracy = R.mean(pass);
    return { accuracy, pass };
  }

  /**
   * Applies the parent's configuration (when it defines one), then reads the
   * test-data generator from `pipelineConfig.Dataset.TestDataGenerator`.
   *
   * @param {Object} pipelineConfig - Pipeline configuration; must contain `Dataset`.
   */
  setByConfig(pipelineConfig) {
    if (super.setByConfig) {
      super.setByConfig(pipelineConfig);
    }
    this.Logger.groupBegin('set Evaluator by config');
    try {
      this.TestDataGenerator = pipelineConfig.Dataset.TestDataGenerator;
    } finally {
      // Keep logger groups balanced even when the config is malformed.
      this.Logger.groupEnd();
    }
  }
};
export default EvaluatorMixins;