import { ParsedPQModel } from './pqModelUtils';

// Positions within each tree's `[leftChildren, rightChildren]` tuple in
// `childNodeIds` (read by getLeftNodeLabel / getRightNodeLabel).
const LEFT_NODE_INDEX = 0;
const RIGHT_NODE_INDEX = 1;

/**
 * Restrict an OFP row to the variables the PQ model was trained on.
 *
 * An OFP produced by a newer GLP run can contain reference flavors that a PQ model trained
 * against an older GLP model has never seen; those extra keys are dropped here. Keys the
 * model expects but the row lacks are NOT added — they are simply absent from the result.
 * @param inputRow A single row of an OFP, keyed by variable name.
 * @param modelVariableNames Variable names in the Random Forest model.
 * @returns A new object holding only the entries of `inputRow` whose keys are model
 *   variables, preserving `inputRow`'s key enumeration order.
 */
export const sanitizePredictionInput = (
  inputRow: { [x: string]: number },
  modelVariableNames: ParsedPQModel['independentVariableNames'],
): { [x: string]: number } => {
  // A Set gives O(1) membership checks, so this is O(keys + variables) instead of the
  // O(keys * variables) cost of calling Array#includes once per key.
  const modelVariables = new Set(modelVariableNames);
  return Object.fromEntries(
    Object.entries(inputRow).filter(([key]) => modelVariables.has(key)),
  );
};

/**
 * Look up the left child of a node.
 * @param treeID Index of the tree within the forest.
 * @param nodeIndex Index of the parent node within that tree.
 * @param childNodeIds Per-tree `[leftChildren, rightChildren]` arrays, each indexed by node.
 * @returns The left child's node index. A node whose left and right children are both 0
 *   is treated as terminal by getDecisionTreePQ.
 */
export const getLeftNodeLabel = (
  treeID: number,
  nodeIndex: number,
  childNodeIds: [number[], number[]][],
) => {
  const leftChildren = childNodeIds[treeID][LEFT_NODE_INDEX];
  return leftChildren[nodeIndex];
};

/**
 * Look up the right child of a node.
 * @param treeID Index of the tree within the forest.
 * @param nodeIndex Index of the parent node within that tree.
 * @param childNodeIds Per-tree `[leftChildren, rightChildren]` arrays, each indexed by node.
 * @returns The right child's node index. A node whose left and right children are both 0
 *   is treated as terminal by getDecisionTreePQ.
 */
export const getRightNodeLabel = (
  treeID: number,
  nodeIndex: number,
  childNodeIds: [number[], number[]][],
) => {
  const rightChildren = childNodeIds[treeID][RIGHT_NODE_INDEX];
  return rightChildren[nodeIndex];
};

/**
 * Choose which child to visit next by comparing the input row's value with the node's threshold.
 *
 * Values less than or equal to the threshold go left; everything else goes right. Because the
 * test is `<=`, a NaN or `undefined` input value compares false and therefore goes right.
 * @param decisionValue The node's decision threshold from the model.
 * @param inputValue The input row's value for the split variable (from getSplitVariableValue).
 * @param leftChildNode Node index to return when `inputValue <= decisionValue`.
 * @param rightChildNode Node index to return otherwise.
 * @returns The index of the next node to visit.
 */
export const getNextNodeIndex = (
  decisionValue: number,
  inputValue: number,
  leftChildNode: number,
  rightChildNode: number,
): number =>
  inputValue <= decisionValue ? leftChildNode : rightChildNode;

/**
 * Resolve the name of the variable a node splits on.
 *
 * `splitVarIds[treeID][nodeIndex]` holds an index into `independentVariableNames`. The
 * returned name is the key used to read the comparison value from the input row. This is
 * not the threshold itself, which comes from getDecisionThresholdValue.
 * @param treeID Index of the tree within the forest.
 * @param nodeIndex Index of the node within that tree.
 * @param independentVariableNames Variable names in the model, indexed by variable id.
 * @param splitVarIds Per-tree, per-node variable ids.
 * @returns Key of the variable this node makes its decision on.
 */
export const getSplitVariableKey = (
  treeID: number,
  nodeIndex: number,
  independentVariableNames: ParsedPQModel['independentVariableNames'],
  splitVarIds: number[][],
): string => independentVariableNames[splitVarIds[treeID][nodeIndex]];

/**
 * Read the RFF / texture value for a split variable from the input row.
 *
 * This value comes from the input data and is compared against the model's threshold
 * (see getDecisionThresholdValue). If the row lacks the key, the lookup yields
 * `undefined` at runtime.
 * @param splitVarKey Variable name, as returned by getSplitVariableKey.
 * @param inputRow A single (sanitized) OFP row keyed by variable name.
 * @returns The input row's value for `splitVarKey`.
 */
export const getSplitVariableValue = (
  splitVarKey: string,
  inputRow: { [x: string]: number },
) => inputRow[splitVarKey];

/**
 * Look up a node's decision threshold in the model.
 *
 * During traversal, an input value less than or equal to this threshold sends the walk to
 * the left child (see getNextNodeIndex).
 * @param treeID Index of the tree within the forest.
 * @param nodeIndex Index of the node within that tree.
 * @param splitValues Per-tree, per-node thresholds from the model.
 * @returns The threshold the input's split value is compared against.
 */
export const getDecisionThresholdValue = (
  treeID: number,
  nodeIndex: number,
  splitValues: ParsedPQModel['splitValues'],
) => {
  const treeThresholds = splitValues[treeID];
  return treeThresholds[nodeIndex];
};

/**
 * Fetch the class-count array stored for a node.
 * @param treeID Index of the tree within the forest.
 * @param nodeIndex Index of the (terminal) node within that tree.
 * @param terminalClassCounts Per-tree, per-node class-count arrays.
 * @returns The node's class-count array. Per the original author this has length 7,
 *   with 0s and a single 1 — not enforced here.
 */
export const getTerminalClassCount = (
  treeID: number,
  nodeIndex: number,
  terminalClassCounts: number[][][],
) => {
  const treeCounts = terminalClassCounts[treeID];
  return treeCounts[nodeIndex];
};

/**
 * Convert a terminal node's class-count array into a PQ value.
 *
 * Computes the dot product of the class counts with the class values. For a one-hot count
 * array this is simply the selected class's value. The result is then clamped to [1, 7].
 * @param classValues From model.classValues; maps each count index to its PQ value.
 * @param terminalClassCount The decision tree's PQ prediction array.
 * @returns PQ value constrained between 1 and 7 (NaN propagates unchanged).
 */
export const terminalClassToPQ = (
  classValues: ParsedPQModel['classValues'],
  terminalClassCount: number[],
): number => {
  let weightedSum = 0;
  terminalClassCount.forEach((count, classIndex) => {
    weightedSum += count * classValues[classIndex];
  });
  return Math.max(1, Math.min(7, weightedSum));
};

/**
 * Count the nodes in a tree whose class-count array has exactly 7 entries.
 *
 * getDecisionTreePQ uses this count as the upper bound on traversal steps, so a malformed
 * tree cannot cause an endless walk.
 * @param treeID Index of the tree within the forest.
 * @param terminalClassCounts Per-tree, per-node class-count arrays from the model.
 * @returns Number of nodes in the tree with a length-7 class-count array.
 */
export const numberOfDecisionPointsInTree = (
  treeID: number,
  terminalClassCounts: ParsedPQModel['terminalClassCounts'],
) =>
  terminalClassCounts[treeID].reduce(
    (total, nodeClasses) => (nodeClasses.length === 7 ? total + 1 : total),
    0,
  );

/**
 * Walk a single decision tree from its root to a terminal node and return that node's PQ.
 *
 * The walk starts at node 0. At each node, the input row's value for the node's split
 * variable is compared with the node's threshold: `<=` goes left, anything else goes right.
 * A node whose left and right child labels are both 0 is terminal. Its class counts are
 * converted to a PQ by terminalClassToPQ.
 *
 * The walk is capped at numberOfDecisionPointsInTree steps, so malformed child links cannot
 * loop forever.
 * @param treeID Index of the tree within the forest.
 * @param model Parsed random-forest model.
 * @param inputRow A single (sanitized) OFP row keyed by variable name.
 * @returns The PQ predicted by this tree's terminal node, clamped to 1–7 by terminalClassToPQ.
 * @throws Error if no terminal node is reached within the step limit.
 */
export const getDecisionTreePQ = (
  treeID: number,
  model: ParsedPQModel,
  inputRow: { [x: string]: number },
): number => {
  const maxDecisionPoints = numberOfDecisionPointsInTree(
    treeID,
    model.terminalClassCounts,
  );

  let nodeIndex = 0;
  for (let step = 0; step < maxDecisionPoints; step++) {
    const leftChildNode = getLeftNodeLabel(
      treeID,
      nodeIndex,
      model.childNodeIds,
    );
    const rightChildNode = getRightNodeLabel(
      treeID,
      nodeIndex,
      model.childNodeIds,
    );

    // Both child labels being 0 marks a terminal node: read its prediction and stop.
    if (leftChildNode === 0 && rightChildNode === 0) {
      const terminalClass = getTerminalClassCount(
        treeID,
        nodeIndex,
        model.terminalClassCounts,
      );
      return terminalClassToPQ(model.classValues, terminalClass);
    }

    const splitVarKey = getSplitVariableKey(
      treeID,
      nodeIndex,
      model.independentVariableNames,
      model.splitVarIds,
    );
    // NOTE(review): if inputRow lacks splitVarKey this is undefined, the `<=` comparison in
    // getNextNodeIndex is false, and the walk silently goes right — confirm that is intended.
    const splitValue = getSplitVariableValue(splitVarKey, inputRow);
    const decisionValue = getDecisionThresholdValue(
      treeID,
      nodeIndex,
      model.splitValues,
    );
    nodeIndex = getNextNodeIndex(
      decisionValue,
      splitValue,
      leftChildNode,
      rightChildNode,
    );
  }
  throw new Error(
    `Cannot compute PQ value: tree ${treeID} did not reach a terminal node within ${maxDecisionPoints} steps`,
  );
};

/**
 * Average the PQ predictions of every tree in the forest for one input row.
 *
 * Trees are visited in order 0..numTrees-1. When `model.numTrees` is 0, the result is
 * 0 / 0, i.e. NaN.
 * @param model Parsed random-forest model.
 * @param inputRow A single (sanitized) OFP row keyed by variable name.
 * @returns Mean PQ predicted by the random forest.
 */
export const getForestPQ = (
  model: ParsedPQModel,
  inputRow: { [x: string]: number },
) => {
  let pqSum = 0;
  for (const treeID of Array(model.numTrees).keys()) {
    pqSum += getDecisionTreePQ(treeID, model, inputRow);
  }
  return pqSum / model.numTrees;
};

/**
 * Mean PQ over multiple rows of an OFP / Market DB.
 *
 * Each row is first restricted to the model's variables with sanitizePredictionInput. Its
 * PQ is then computed by getForestPQ. An empty `inputRows` gives 0 / 0, i.e. NaN.
 * @param model Parsed random-forest model.
 * @param inputRows Rows keyed by variable name.
 * @returns Mean of the per-row forest PQ predictions.
 */
export const getFinalMeanPQ = (
  model: ParsedPQModel,
  inputRows: { [x: string]: number }[],
) => {
  let totalPQ = 0;
  inputRows.forEach((row) => {
    const modelRow = sanitizePredictionInput(
      row,
      model.independentVariableNames,
    );
    totalPQ += getForestPQ(model, modelRow);
  });
  return totalPQ / inputRows.length;
};
