/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.hmm.HiddenMarkovModel;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A Hidden Markov Model that runs the per-time-step and per-state pieces of
 * Baum-Welch learning and Viterbi decoding as {@link Callable} tasks on a
 * {@link ThreadPoolExecutor}.
 * <p>
 * The task objects are cached in transient, per-instance lists and are mutated
 * in place on every call without synchronization. For that reason the
 * overridden computation methods of a single instance should not be invoked
 * concurrently from several threads. Clones get their own (empty) task caches.
 *
 * @param <ObservationType> Type of the observations emitted by the model.
 */
@PublicationReference(author={"William Turin"}, title="Unidirectional and Parallel Baum\u2013Welch Algorithms", type=PublicationType.Journal, publication="IEEE Transactions on Speech and Audio Processing", year=1998, url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318")
public class ParallelHiddenMarkovModel<ObservationType>
extends HiddenMarkovModel<ObservationType>
implements ParallelAlgorithm {
    /** Thread pool used to run the tasks; created lazily by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /** Cached observation-likelihood tasks; no method in this class populates it. */
    protected transient ArrayList<ObservationLikelihoodTask<ObservationType>> observationLikelihoodTasks;

    /** Cached tasks for {@link #computeTransitions(ArrayList, ArrayList, ArrayList)}, one per transition. */
    protected transient ArrayList<ComputeTransitionsTask> computeTransitionTasks;

    /** Cached tasks for {@link #normalizeTransitionMatrix(Matrix)}, one per column. */
    protected transient ArrayList<NormalizeTransitionTask> normalizeTransitionTasks;

    /** Cached tasks for {@link #computeStateObservationLikelihood(ArrayList, ArrayList, double)}, one per time step. */
    protected transient ArrayList<StateObservationLikelihoodTask> stateObservationLikelihoodTasks;

    /** Cached tasks for {@link #computeViterbiRecursion(Vector, Vector)}, one per destination state. */
    protected transient ArrayList<ViterbiTask> viterbiTasks;

    /** Creates a model using the superclass's default configuration. */
    public ParallelHiddenMarkovModel() {
    }

    /**
     * Creates a model with the given number of states.
     *
     * @param numStates Number of hidden states.
     */
    public ParallelHiddenMarkovModel(int numStates) {
        super(numStates);
    }

    /**
     * Creates a model from explicit parameters.
     *
     * @param initialProbability Initial state probability vector.
     * @param transitionProbability State transition probability matrix.
     * @param emissionFunctions Per-state emission distributions.
     */
    public ParallelHiddenMarkovModel(Vector initialProbability, Matrix transitionProbability, Collection<? extends ComputableDistribution<ObservationType>> emissionFunctions) {
        super(initialProbability, transitionProbability, emissionFunctions);
    }

    /**
     * Copy constructor: clones the parameters of another model so the two do not
     * share mutable state.
     *
     * @param other Model whose parameters are copied.
     */
    public ParallelHiddenMarkovModel(HiddenMarkovModel<ObservationType> other) {
        this(ObjectUtil.cloneSafe(other.getInitialProbability()),
            ObjectUtil.cloneSafe(other.getTransitionProbability()),
            ObjectUtil.cloneSmartElementsAsArrayList(other.getEmissionFunctions()));
    }

    /**
     * Clones this model. The transient task caches are reset on the clone: they
     * are scratch space, and {@link ViterbiTask} is bound to the instance that
     * created it, so sharing them would make the clone evaluate with this
     * model's parameters and resize this model's lists. The thread pool is shared.
     *
     * @return A clone of this model with empty task caches.
     */
    @Override
    public ParallelHiddenMarkovModel<ObservationType> clone() {
        @SuppressWarnings("unchecked")
        final ParallelHiddenMarkovModel<ObservationType> clone =
            (ParallelHiddenMarkovModel<ObservationType>) super.clone();
        clone.observationLikelihoodTasks = null;
        clone.computeTransitionTasks = null;
        clone.normalizeTransitionTasks = null;
        clone.stateObservationLikelihoodTasks = null;
        clone.viterbiTasks = null;
        return clone;
    }

    /**
     * Returns the thread pool, creating one via {@link ParallelUtil#createThreadPool()}
     * on first use.
     *
     * @return The thread pool used to run this model's tasks.
     */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to run this model's tasks.
     *
     * @param threadPool Thread pool to use; {@code null} means one is created lazily.
     */
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Returns the number of threads, as reported by {@link ParallelUtil}.
     *
     * @return Number of threads available to this algorithm.
     */
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Sums the log-likelihoods of several observation sequences, computing each
     * sequence's log-likelihood as a separate task.
     * NOTE(review): this assumes computeObservationLogLikelihood is safe to call
     * concurrently on this instance -- confirm in the superclass.
     *
     * @param sequences Observation sequences to evaluate.
     * @return Sum of the per-sequence log-likelihoods (0.0 when there are none).
     */
    @Override
    public double computeMultipleObservationLogLikelihood(Collection<? extends Collection<? extends ObservationType>> sequences) {
        final ArrayList<LogLikelihoodTask> tasks = new ArrayList<LogLikelihoodTask>(sequences.size());
        for (Collection<? extends ObservationType> sequence : sequences) {
            tasks.add(new LogLikelihoodTask(sequence));
        }
        double logLikelihood = 0.0;
        for (Double sequenceLogLikelihood : this.runInParallel(tasks)) {
            logLikelihood += sequenceLogLikelihood;
        }
        return logLikelihood;
    }

    /**
     * Computes the re-estimated transition matrix. One task per consecutive pair
     * of time steps produces a partial matrix. The partial matrices are summed,
     * multiplied element-wise by the current transition probabilities, and
     * normalized with {@link #normalizeTransitionMatrix(Matrix)}.
     *
     * @param alphas Forward values per time step.
     * @param betas Backward values per time step.
     * @param b Observation likelihoods per time step.
     * @return The normalized transition matrix.
     * @throws IllegalArgumentException if fewer than two time steps are given.
     */
    @Override
    protected Matrix computeTransitions(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, ArrayList<Vector> b) {
        final int numTransitions = alphas.size() - 1;
        if (numTransitions < 1) {
            // With no transitions there is nothing to accumulate, and the sum would be null.
            throw new IllegalArgumentException(
                "At least two time steps are required to compute transitions, got " + alphas.size());
        }
        if (this.computeTransitionTasks == null) {
            this.computeTransitionTasks = new ArrayList<ComputeTransitionsTask>(numTransitions);
        }
        trimTasks(this.computeTransitionTasks, numTransitions);
        while (this.computeTransitionTasks.size() < numTransitions) {
            this.computeTransitionTasks.add(new ComputeTransitionsTask());
        }
        for (int n = 0; n < numTransitions; ++n) {
            final ComputeTransitionsTask task = this.computeTransitionTasks.get(n);
            task.alphan = alphas.get(n).getValue();
            task.betanp1 = betas.get(n + 1).getValue();
            task.bnp1 = b.get(n + 1);
        }
        final RingAccumulator<Matrix> counts = new RingAccumulator<Matrix>();
        for (Matrix partial : this.runInParallel(this.computeTransitionTasks)) {
            counts.accumulate(partial);
        }
        final Matrix A = counts.getSum();
        A.dotTimesEquals(this.getTransitionProbability());
        this.normalizeTransitionMatrix(A);
        return A;
    }

    /**
     * Normalizes the given transition matrix in place. Each column is handled by
     * its own task, which calls the static per-column normalization.
     *
     * @param A Matrix to normalize in place.
     */
    @Override
    protected void normalizeTransitionMatrix(Matrix A) {
        final int k = A.getNumColumns();
        if (this.normalizeTransitionTasks == null) {
            this.normalizeTransitionTasks = new ArrayList<NormalizeTransitionTask>(k);
        }
        trimTasks(this.normalizeTransitionTasks, k);
        while (this.normalizeTransitionTasks.size() < k) {
            this.normalizeTransitionTasks.add(new NormalizeTransitionTask());
        }
        for (int j = 0; j < k; ++j) {
            final NormalizeTransitionTask task = this.normalizeTransitionTasks.get(j);
            task.j = j;
            task.A = A;
        }
        this.runInParallel(this.normalizeTransitionTasks);
    }

    /**
     * Computes the per-time-step state likelihoods. Each time step is a separate
     * task that combines that step's alpha and beta vectors.
     *
     * @param alphas Forward values per time step.
     * @param betas Backward values per time step.
     * @param scaleFactor Scale factor forwarded to the static per-step helper.
     * @return One likelihood vector per time step, in time order.
     */
    @Override
    protected ArrayList<Vector> computeStateObservationLikelihood(ArrayList<WeightedValue<Vector>> alphas, ArrayList<WeightedValue<Vector>> betas, double scaleFactor) {
        final int N = alphas.size();
        if (this.stateObservationLikelihoodTasks == null) {
            this.stateObservationLikelihoodTasks = new ArrayList<StateObservationLikelihoodTask>(N);
        }
        trimTasks(this.stateObservationLikelihoodTasks, N);
        while (this.stateObservationLikelihoodTasks.size() < N) {
            this.stateObservationLikelihoodTasks.add(new StateObservationLikelihoodTask());
        }
        for (int n = 0; n < N; ++n) {
            final StateObservationLikelihoodTask task = this.stateObservationLikelihoodTasks.get(n);
            task.alpha = alphas.get(n).getValue();
            task.beta = betas.get(n).getValue();
            // Honor the caller's scale factor instead of a hard-coded 1.0.
            task.scaleFactor = scaleFactor;
        }
        return this.runInParallel(this.stateObservationLikelihoodTasks);
    }

    /**
     * Performs one Viterbi recursion step. Each destination state is a separate
     * task. The resulting back-pointers and scores are collected, the scores are
     * multiplied element-wise by {@code bn`, and the result is scaled to unit
     * L1 norm.
     *
     * @param delta Current delta vector.
     * @param bn Observation likelihoods for the current time step.
     * @return The next (normalized) delta vector paired with the back-pointer array.
     */
    @Override
    protected Pair<Vector, int[]> computeViterbiRecursion(Vector delta, Vector bn) {
        final int k = this.getNumStates();
        if (this.viterbiTasks == null) {
            this.viterbiTasks = new ArrayList<ViterbiTask>(k);
        }
        trimTasks(this.viterbiTasks, k);
        while (this.viterbiTasks.size() < k) {
            this.viterbiTasks.add(new ViterbiTask());
        }
        for (int i = 0; i < k; ++i) {
            final ViterbiTask task = this.viterbiTasks.get(i);
            task.destinationState = i;
            task.delta = delta;
        }
        final ArrayList<WeightedValue<Integer>> results = this.runInParallel(this.viterbiTasks);
        final int[] psis = new int[k];
        final Vector nextDelta = VectorFactory.getDefault().createVector(k);
        for (int i = 0; i < k; ++i) {
            final WeightedValue<Integer> best = results.get(i);
            psis[i] = best.getValue();
            nextDelta.setElement(i, best.getWeight());
        }
        nextDelta.dotTimesEquals(bn);
        nextDelta.scaleEquals(1.0 / nextDelta.norm1());
        return DefaultPair.create(nextDelta, psis);
    }

    /**
     * Runs the tasks on this model's thread pool and returns their results in
     * task order. Any failure is wrapped in a {@link RuntimeException}. If the
     * failure is an interruption, the thread's interrupt flag is restored first.
     */
    private <ResultType> ArrayList<ResultType> runInParallel(final Collection<? extends Callable<ResultType>> tasks) {
        try {
            return ParallelUtil.executeInParallel(tasks, this.getThreadPool());
        }
        catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                // Preserve the interrupt so code further up the stack can observe it.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }
    }

    /** Shrinks a cached task list to at most {@code size} entries and reserves capacity for it. */
    private static void trimTasks(final ArrayList<?> tasks, final int size) {
        tasks.ensureCapacity(size);
        if (tasks.size() > size) {
            tasks.subList(size, tasks.size()).clear();
        }
    }

    /** Task computing the log-likelihood of one observation sequence under the enclosing model. */
    protected class LogLikelihoodTask
    extends AbstractCloneableSerializable
    implements Callable<Double> {
        /** Observation sequence to evaluate. */
        protected Collection<? extends ObservationType> data;

        /**
         * Creates the task.
         *
         * @param data Observation sequence to evaluate.
         */
        public LogLikelihoodTask(Collection<? extends ObservationType> data) {
            this.data = data;
        }

        @Override
        public Double call() throws Exception {
            return ParallelHiddenMarkovModel.this.computeObservationLogLikelihood(this.data);
        }
    }

    /**
     * Task computing one destination state's Viterbi step on the enclosing model.
     * As an inner class, it is bound to the instance that created it.
     */
    protected class ViterbiTask
    extends AbstractCloneableSerializable
    implements Callable<WeightedValue<Integer>> {
        /** Destination state index handled by this task. */
        int destinationState;
        /** Current delta vector. */
        Vector delta;

        ViterbiTask() {
        }

        @Override
        public WeightedValue<Integer> call() throws Exception {
            return ParallelHiddenMarkovModel.this.findMostLikelyState(this.destinationState, this.delta);
        }
    }

    /** Task computing the partial transition matrix for one pair of consecutive time steps. */
    protected static class ComputeTransitionsTask
    extends AbstractCloneableSerializable
    implements Callable<Matrix> {
        /** Forward values at time n. */
        Vector alphan;
        /** Backward values at time n+1. */
        Vector betanp1;
        /** Observation likelihoods at time n+1. */
        Vector bnp1;

        @Override
        public Matrix call() {
            return ParallelHiddenMarkovModel.computeTransitions(this.alphan, this.betanp1, this.bnp1);
        }
    }

    /** Task normalizing one column of a transition matrix in place. */
    protected static class NormalizeTransitionTask
    extends AbstractCloneableSerializable
    implements Callable<Void> {
        /** Matrix being normalized. */
        private Matrix A;
        /** Column index handled by this task. */
        private int j;

        @Override
        public Void call() {
            ParallelHiddenMarkovModel.normalizeTransitionMatrix(this.A, this.j);
            return null;
        }
    }

    /** Task computing the state likelihood vector for one time step. */
    protected static class StateObservationLikelihoodTask
    extends AbstractCloneableSerializable
    implements Callable<Vector> {
        /** Forward values for the time step. */
        protected Vector alpha;
        /** Backward values for the time step. */
        protected Vector beta;
        /** Scale factor forwarded to the static helper; defaults to 1.0. */
        protected double scaleFactor = 1.0;

        @Override
        public Vector call() throws Exception {
            return ParallelHiddenMarkovModel.computeStateObservationLikelihood(this.alpha, this.beta, this.scaleFactor);
        }
    }

    /**
     * Task evaluating one probability function over every observation in a sequence.
     *
     * @param <ObservationType> Type of the observations.
     */
    protected static class ObservationLikelihoodTask<ObservationType>
    extends AbstractCloneableSerializable
    implements Callable<double[]> {
        /** Observations to evaluate. */
        protected Collection<? extends ObservationType> observations;
        /** Function evaluated on each observation. */
        protected ProbabilityFunction<ObservationType> distributionFunction;

        /**
         * Evaluates the function on each observation.
         *
         * @return The function values, in the collection's iteration order.
         */
        @Override
        public double[] call() {
            final double[] b = new double[this.observations.size()];
            int n = 0;
            for (ObservationType observation : this.observations) {
                b[n++] = (Double)this.distributionFunction.evaluate(observation);
            }
            return b;
        }
    }
}

