import { useCallback, useRef, useMemo, useState, useEffect } from "react";
import _ from "lodash";
import { compareTwoGridStates, getAllRows } from "../utils";

/**
 * Snapshot-based undo/redo for a data grid.
 *
 * Whole-grid snapshots (deep clones of `getAllRows(gridApi)`) are kept in a
 * bounded history stack. Undo/redo move a pointer through that stack, push the
 * selected snapshot back into the grid via `gridApi.setRowData`, call
 * `updateData`, and flash the cells reported by `compareTwoGridStates`.
 *
 * NOTE(review): `gridApi` is used through `getRowNode`, `flashCells` and
 * `setRowData`, which match AG Grid's GridApi — presumably that is the grid in
 * use; confirm.
 *
 * @param {object|null|undefined} gridApi - Grid API; may be falsy before the
 *   grid is ready, in which case initialise/reset/snapshot calls do nothing.
 * @param {(rows: Array, updateInfo: object) => void} updateData - Invoked after
 *   every successful undo/redo with a deep copy of the restored rows and the
 *   diff returned by `compareTwoGridStates`.
 * @param {number} [maxStackSize=5] - Maximum number of snapshots retained
 *   (clamped to at least 1); the oldest snapshots are dropped first.
 * @returns {{
 *   handleKeyDownForUndoRedo: Function,
 *   triggerUndoRedoSnapshotCollection: Function,
 *   initialiseUndoRedoStack: Function,
 *   resetUndoRedoStack: Function,
 * }}
 */
export const useGridUndoRedo = (gridApi, updateData, maxStackSize = 5) => {
  // History of grid snapshots, oldest first.
  const [stack, setStack] = useState([]);
  // Index in `stack` of the snapshot currently shown; null until initialised.
  const [stackIdxPointer, setStackIdxPointer] = useState(null);
  // Throttle: while true, further undo/redo requests are ignored.
  const [isUndoRedoInProgress, setIsUndoRedoInProgress] = useState(false);

  // Release the undo/redo throttle 700 ms after it was taken. The cleanup
  // cancels the pending timer on unmount so it never updates dead state.
  useEffect(() => {
    if (isUndoRedoInProgress !== true) {
      return undefined;
    }
    const timerId = setTimeout(() => {
      setIsUndoRedoInProgress(false);
    }, 700);
    return () => clearTimeout(timerId);
  }, [isUndoRedoInProgress]);

  // True while a debounced snapshot is pending (see triggerUndoRedoSnapshotCollection).
  const isGridUpdating = useRef(false);
  // Ensures initialiseUndoRedoStack only seeds the history once.
  const isFirstSnapshotLoaded = useRef(false);

  const currentState = useMemo(() => {
    return stack[stackIdxPointer];
  }, [stack, stackIdxPointer]);

  /** Seed the history with the grid's current rows (first call with a grid only). */
  const initialiseUndoRedoStack = useCallback(() => {
    if (gridApi && !isFirstSnapshotLoaded.current) {
      const gridSnapshot = _.cloneDeep(getAllRows(gridApi));
      setStack([gridSnapshot]);
      setStackIdxPointer(0);

      isFirstSnapshotLoaded.current = true;
    }
  }, [gridApi]);

  /**
   * Flash the changed cells. `updateInfo` maps a column key to the list of
   * changed rows (each carrying an `id`). Rows the grid no longer knows about
   * (getRowNode returns nothing) are skipped rather than passed to flashCells.
   */
  const flashCells = (updateInfo) => {
    for (const [columnKey, changedRows] of Object.entries(updateInfo || {})) {
      const rowNodes = changedRows
        .map((row) => gridApi.getRowNode(row.id))
        .filter(Boolean);
      if (rowNodes.length > 0) {
        gridApi.flashCells({ rowNodes, columns: [columnKey] });
      }
    }
  };

  /**
   * Restore the snapshot at `targetIdx` into the grid, move the pointer there,
   * notify the caller and flash the changed cells. Shared by undo and redo.
   */
  const goToSnapshot = (targetIdx) => {
    const targetSnapshot = stack[targetIdx];
    // Only restore between snapshots with the same row count (as before).
    // NOTE(review): presumably because compareTwoGridStates diffs row-by-row.
    if (
      !currentState ||
      !targetSnapshot ||
      currentState.length !== targetSnapshot.length
    ) {
      return;
    }

    setIsUndoRedoInProgress(true);

    // Work on a copy so nothing downstream can mutate the stored history.
    const restoredRows = _.cloneDeep(targetSnapshot);
    const updateInfo = compareTwoGridStates(currentState, restoredRows);
    setStackIdxPointer(targetIdx);
    gridApi.setRowData(_.cloneDeep(restoredRows));
    updateData(restoredRows, updateInfo);

    flashCells(updateInfo);
  };

  const undo = () => {
    if (
      isUndoRedoInProgress ||
      stackIdxPointer === null ||
      stackIdxPointer <= 0
    ) {
      return;
    }
    goToSnapshot(stackIdxPointer - 1);
  };

  const redo = () => {
    if (
      isUndoRedoInProgress ||
      stackIdxPointer === null ||
      stackIdxPointer >= stack.length - 1
    ) {
      return;
    }
    goToSnapshot(stackIdxPointer + 1);
  };

  /**
   * Accepts either the strings "undo"/"redo" or a keyboard-event-like object:
   * Ctrl/Cmd + "z" undoes, Ctrl/Cmd + "y" redoes. Anything else is ignored.
   */
  const handleKeyDownForUndoRedo = (e) => {
    if (e === "undo") {
      undo();
      return;
    }
    if (e === "redo") {
      redo();
      return;
    }
    if (!e || !(e.metaKey || e.ctrlKey)) {
      return;
    }
    if (e.key === "z") {
      undo();
    } else if (e.key === "y") {
      redo();
    }
  };

  /**
   * Append the grid's current rows as a new snapshot, discarding any redoable
   * states after the pointer and trimming the oldest beyond `maxStackSize`.
   * The pending-snapshot flag is always cleared, even if reading the grid
   * throws, so later snapshots are not blocked forever.
   */
  const saveToStack = () => {
    try {
      if (!gridApi) {
        return;
      }
      const gridSnapshot = _.cloneDeep(getAllRows(gridApi));
      const stackWithoutRedoableStates = stack.slice(0, stackIdxPointer + 1);
      const limit = Math.max(1, maxStackSize);
      const newStack = [...stackWithoutRedoableStates, gridSnapshot].slice(
        -limit
      );
      setStack(newStack);
      // The new snapshot is always the last entry.
      setStackIdxPointer(newStack.length - 1);
    } finally {
      isGridUpdating.current = false;
    }
  };

  // Latest-callback ref: the delayed snapshot must read the stack/pointer of
  // the most recent committed render, not the render that scheduled it.
  const saveToStackRef = useRef(saveToStack);
  useEffect(() => {
    saveToStackRef.current = saveToStack;
  });

  /**
   * Debounced snapshot request: the first call schedules a snapshot 100 ms
   * later; further calls before it runs are ignored (so a burst of edits
   * produces a single history entry).
   */
  const triggerUndoRedoSnapshotCollection = () => {
    if (isGridUpdating.current) {
      return;
    }
    isGridUpdating.current = true;

    setTimeout(() => {
      saveToStackRef.current();
    }, 100);
  };

  /** Replace the whole history with a single snapshot of the current grid rows. */
  const resetUndoRedoStack = useCallback(() => {
    if (!gridApi) {
      return;
    }
    const gridSnapshot = _.cloneDeep(getAllRows(gridApi));
    setStack([gridSnapshot]);
    setStackIdxPointer(0);
  }, [gridApi]);

  return {
    handleKeyDownForUndoRedo,
    triggerUndoRedoSnapshotCollection,
    initialiseUndoRedoStack,
    resetUndoRedoStack,
  };
};
