/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.util.ShrinkingMatrix;
import org.tribuo.math.optimisers.util.ShrinkingTensor;
import org.tribuo.math.optimisers.util.ShrinkingVector;

/**
 * A stochastic gradient optimiser following the Pegasos step-size schedule.
 * <p>
 * On each call to {@link #step(Tensor[], double)} the updates are scaled by
 * {@code weight * baseRate / (lambda * t)}, where {@code t} is a 1-based step counter.
 * <p>
 * {@link #initialise(Parameters)} replaces each dense parameter with a {@link ShrinkingVector} or
 * {@link ShrinkingMatrix} built from {@code baseRate} and {@code lambda}.
 * {@link #finalise()} converts them back via {@link ShrinkingTensor#convertToDense()}.
 * <p>
 * Instances hold mutable per-run state (the step counter and the wrapped parameters) without
 * synchronisation. Use {@link #copy()} to obtain an independent optimiser.
 */
public class Pegasos implements StochasticGradientOptimiser {
    @Config(description="Step size shrinkage.")
    private double lambda = 0.01;
    @Config(description="Base learning rate.")
    private double baseRate = 0.1;
    // 1-based step counter, advanced by step() and restored to 1 by reset().
    private int iteration = 1;
    // Parameters wrapped by initialise(); null before initialise() and after reset().
    private Parameters parameters;

    /**
     * No-arg constructor keeping the field defaults (lambda = 0.01, baseRate = 0.1).
     * Presumably used by the OLCUT configuration system via reflection; TODO confirm.
     */
    private Pegasos() {
    }

    /**
     * Creates a Pegasos optimiser.
     *
     * @param baseRate The base learning rate.
     * @param lambda   The step size shrinkage.
     */
    public Pegasos(double baseRate, double lambda) {
        this.baseRate = baseRate;
        this.lambda = lambda;
    }

    /**
     * Wraps every parameter tensor in a shrinking tensor and installs the wrapped array back into
     * {@code parameters}.
     *
     * @param parameters The parameters to optimise; retained for {@link #finalise()}.
     * @throws IllegalStateException If a parameter is neither a {@link DenseVector} nor a
     *                               {@link DenseMatrix}.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.parameters = parameters;
        Tensor[] curParams = parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
        for (int i = 0; i < newParams.length; ++i) {
            Tensor cur = curParams[i];
            if (cur instanceof DenseVector) {
                newParams[i] = new ShrinkingVector((DenseVector) cur, this.baseRate, this.lambda);
            } else if (cur instanceof DenseMatrix) {
                newParams[i] = new ShrinkingMatrix((DenseMatrix) cur, this.baseRate, this.lambda);
            } else {
                throw new IllegalStateException("Unknown Tensor subclass");
            }
        }
        parameters.set(newParams);
    }

    /**
     * Scales the supplied updates in place by {@code weight * baseRate / (lambda * t)} and then
     * advances the step counter {@code t}.
     *
     * @param updates The gradient updates; mutated in place.
     * @param weight  The example weight.
     * @return The same {@code updates} array, now scaled.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        double etaT = this.baseRate / (this.lambda * (double) this.iteration);
        for (Tensor t : updates) {
            t.scaleInPlace(etaT * weight);
        }
        ++this.iteration;
        return updates;
    }

    @Override
    public String toString() {
        return "Pegasos(baseRate=" + this.baseRate + ",lambda=" + this.lambda + ")";
    }

    /**
     * Converts each wrapped shrinking tensor back into a dense tensor and installs the result
     * into the parameters captured by {@link #initialise(Parameters)}.
     *
     * @throws IllegalStateException If this optimiser has not been initialised, or if any
     *                               parameter is not a {@link ShrinkingTensor}.
     */
    @Override
    public void finalise() {
        // Guard against finalise() before initialise() (or after reset()), which would
        // otherwise surface as an uninformative NullPointerException.
        if (this.parameters == null) {
            throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos");
        }
        Tensor[] curParams = this.parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
        for (int i = 0; i < newParams.length; ++i) {
            if (!(curParams[i] instanceof ShrinkingTensor)) {
                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos");
            }
            newParams[i] = ((ShrinkingTensor) curParams[i]).convertToDense();
        }
        this.parameters.set(newParams);
    }

    /**
     * Drops the reference to the parameters and restarts the step counter at 1.
     */
    @Override
    public void reset() {
        this.parameters = null;
        this.iteration = 1;
    }

    /**
     * Returns a new optimiser with the same {@code baseRate} and {@code lambda}. Training state
     * (the step counter and the wrapped parameters) is not copied.
     *
     * @return A fresh Pegasos optimiser with this configuration.
     */
    @Override
    public Pegasos copy() {
        // Argument order must match the constructor signature: (baseRate, lambda).
        return new Pegasos(this.baseRate, this.lambda);
    }

    /**
     * Returns the configuration provenance of this optimiser.
     *
     * @return The provenance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable) this, "StochasticGradientOptimiser");
    }
}

