Home Reference Source Test

packages/causality-models/src/singleLableClassification.js

import { default as BaseModel } from './baseModel';

/**
 * Single-label (mutually exclusive classes) classification model built on
 * BaseModel.
 *
 * The forward pass is delegated to a `Predictor` function supplied through the
 * `LayerRunner` setter. The tensor helpers are exposed as getters that return
 * closures:
 *   - `Fit`           -> per-class log-probabilities
 *   - `Predict`       -> index of the most probable class (argMax over axis 1)
 *   - `OneHotPredict` -> one-hot encoding of `Predict`, with depth `numClass`
 *   - `Loss`          -> mean over rows of the summed `-logProb * label`
 *
 * NOTE(review): the tensor methods used here (`sub`, `logSumExp`, `argMax`,
 * `neg`, `mul`, `sum`, `mean`) look like a TF.js-style API, and `this.T`
 * presumably is a tensor namespace provided by BaseModel. Confirm this
 * against BaseModel.
 */
class SingleLabelClassification extends BaseModel {
    /**
     * @param {number} numClass - Number of classes. Must be, or coerce via
     *   `Number()` to, a positive integer. It is used as the one-hot depth
     *   in `OneHotPredict`.
     * @throws {Error} If `numClass` is not a positive integer.
     */
    constructor(numClass) {
        super();
        // Coerce first so numeric strings (e.g. values read from config) keep
        // working as before. Then require an integer: a fractional or
        // infinite class count is not a usable one-hot depth.
        const count = Number(numClass);
        if (!Number.isInteger(count) || count <= 0) {
            throw new Error(`expect numclass, get ${numClass}`);
        }
        this.numClass = count;
    }

    /**
     * Stores only the `Predictor` property of the given runner object.
     * Any other properties are dropped.
     * @param {{Predictor: Function}} layerRunner
     */
    set LayerRunner(layerRunner) {
        const { Predictor } = layerRunner;
        this.runner = { Predictor };
    }

    /**
     * @returns {{Predictor: Function}} The runner previously set.
     * @throws {Error} If no runner has been set yet.
     */
    get LayerRunner() {
        if (!this.runner) {
            throw new Error('runner is not set');
        }
        return this.runner;
    }

    /**
     * Returns a function that maps an input tensor to log-probabilities.
     * It takes the Predictor output and subtracts its logSumExp over axis 1,
     * which is a log-softmax over axis 1. The `true` argument presumably
     * keeps the reduced axis so the subtraction broadcasts per row; confirm
     * this against the tensor library.
     * The Predictor is captured when this getter is read, so reading it
     * before a runner is set throws 'runner is not set'.
     * @returns {(inputTensor: *) => *}
     */
    get Fit() {
        const { Predictor } = this.LayerRunner;
        return (inputTensor) => {
            const scores = Predictor(inputTensor);
            return scores.sub(scores.logSumExp(1, true));
        };
    }

    /**
     * Returns a function that yields the argMax over axis 1 of the
     * log-probabilities from `Fit`, i.e. the predicted class index per row.
     * @returns {(inputTensor: *) => *}
     */
    get Predict() {
        const fit = this.Fit;
        return (inputTensor) => fit(inputTensor).argMax(1);
    }

    /**
     * Returns a function that one-hot encodes `Predict`'s class indices with
     * depth `this.numClass`. It uses `this.T.oneHot`, and `this.T` is read at
     * call time rather than when the getter is read.
     * @returns {(inputTensor: *) => *}
     */
    get OneHotPredict() {
        const predict = this.Predict;
        return (inputTensor) => {
            const predictedClass = predict(inputTensor);
            return this.T.oneHot(predictedClass, this.numClass);
        };
    }

    /**
     * Returns a loss function with these steps:
     *   1. multiply the negated log-probabilities element-wise by `labelTensor`;
     *   2. sum over axis 1;
     *   3. take the mean.
     * With one-hot labels (presumably the intended input; confirm with the
     * callers) this is the mean cross-entropy.
     * @returns {(inputTensor: *, labelTensor: *) => *}
     */
    get Loss() {
        const fit = this.Fit;
        return (inputTensor, labelTensor) => {
            const logProb = fit(inputTensor);
            // Element-wise negative log-likelihood terms, masked/weighted by labels.
            const weightedNll = logProb.neg().mul(labelTensor);
            return weightedNll.sum(1).mean();
        };
    }
}
export default SingleLabelClassification;