packages/causality-optimizers/src/trainer.mixins.js
/**
* This mixin class provides attributes: **Optimizer**, **Trainer**, **TrainDataGenerator**,
* methods **train**, handle **Optimizer** setting of pipelineConfig.Net
* and **TrainDataGenerator** setting of pipelineConfig.Dataset.
*
* @class TrainerMixins
* @extends {BasePipelineClass}
* @example
* [EXAMPLE ../examples/trainer.mixins.babel.js]
*/
const TrainerMixins = (BasePipelineClass)=> class extends BasePipelineClass{
get Trainer(){
const T = this.T;
const Loss=this.LossModel, Optimizer=this.Optimizer;
return (sampleTensor, labelTensor)=>{
const LossFn = ()=>{
return T.tidy( ()=>{
Loss(sampleTensor, labelTensor);
return Loss(sampleTensor, labelTensor);
} );
};
return Optimizer.fit(LossFn);
};
}
get Optimizer(){
if(!this.optimizer){
throw Error('optimizer is not set');
}
return this.optimizer;
}
set Optimizer(optimizer){
this.optimizer = optimizer;
}
get TrainDataGenerator(){
if(!this.trainDataGenerator){
throw Error('TrainDataGenerator is not set');
}
return this.trainDataGenerator;
}
set TrainDataGenerator(TrainDataGenerator){
this.trainDataGenerator = TrainDataGenerator;
}
async train(numEpochs, batchSize){
const F = this.F, R = this.F.CoreFunctor, T = this.T;
const TrainDataGenerator = this.TrainDataGenerator, Trainer = this.Trainer;
let losses = [], logger = this.Logger;
return new Promise(async (resolve, reject)=>{
logger.progressBegin(numEpochs);
for(let epochIdx of F.range(numEpochs)){
const TrainData = TrainDataGenerator(batchSize);
let iterLosses = [];
for await (let { samples, labels } of TrainData){
let sampleTensor = T.tensor(samples).asType('float32');
console.log(sampleTensor.shape);
let labelTensor = T.tensor(labels).asType('float32');
let loss = Trainer(sampleTensor, labelTensor);
iterLosses.push(await loss.data());
}
losses.push(R.mean(iterLosses));
iterLosses = [];
logger.progressUpdate({epochIdx, losses, numEpochs});
}
logger.progressEnd();
resolve({losses});
});
}
setByConfig(pipelineConfig){
if(super.setByConfig){
super.setByConfig(pipelineConfig);
}
this.Logger.groupBegin('set Trainer by config');
const { Optimizer } = pipelineConfig.Net;
this.Optimizer = Optimizer;
Optimizer.LayerRunner = this.LayerRunner;
this.TrainDataGenerator = pipelineConfig.Dataset.TrainDataGenerator;
this.Logger.groupEnd();
}
};
export default TrainerMixins;