import { useCallback, useEffect, useState, useRef } from 'react';
import {
  type Node,
  type OnNodesChange,
  applyNodeChanges,
  getConnectedEdges,
  Edge,
} from 'reactflow';

import Y from 'yjs';

// The Yjs map `nodes` (nodesMap) is the single source of truth for the nodes.
// Every change to the nodes is written to that map, and an observer copies the
// map's contents into React state whenever the map changes.

/**
 * Mirrors React Flow nodes to and from the shared Yjs map `ydoc.getMap('nodes')`.
 *
 * The Yjs map is the single source of truth. Both returned setters write to the
 * map inside `ydoc.transact`. An observer copies the map's values into React
 * state whenever the map changes.
 *
 * NOTE(review): the Yjs maps are captured on the first render and the observer is
 * registered once, so `ydoc` is assumed to be stable for the lifetime of the
 * component. To switch documents, remount the component (e.g. via a `key`).
 *
 * @param ydoc - Yjs document holding the `nodes` and `edges` maps.
 * @returns A `[nodes, setNodes, onNodesChange]` tuple:
 *   - `nodes`: the current node list, copied from the map.
 *   - `setNodes`: a `useState`-style setter. It writes every node of the next
 *     list to the map and deletes map entries whose ids are not in that list.
 *     It does NOT delete connected edges.
 *   - `onNodesChange`: a React Flow change handler that applies changes to the
 *     map. When a node is removed, it also deletes that node's connected edges
 *     from the `edges` map.
 */
function useNodesStateSynced(
  ydoc: Y.Doc,
): [Node[], React.Dispatch<React.SetStateAction<Node[]>>, OnNodesChange] {
  const nodesMapRef = useRef(ydoc.getMap<Node>('nodes'));
  const edgesMapRef = useRef(ydoc.getMap<Edge>('edges'));

  // Seed from the map so nodes that already exist when the hook mounts are
  // rendered immediately instead of only after the next map change.
  const [nodes, setNodes] = useState<Node[]>(() =>
    Array.from(nodesMapRef.current.values()),
  );

  const setNodesSynced = useCallback(
    (nodesOrUpdater: React.SetStateAction<Node[]>) => {
      ydoc.transact(() => {
        const nodesMap = nodesMapRef.current;
        const next =
          typeof nodesOrUpdater === 'function'
            ? nodesOrUpdater(Array.from(nodesMap.values()))
            : nodesOrUpdater;

        const seen = new Set<string>();
        for (const node of next) {
          seen.add(node.id);
          nodesMap.set(node.id, node);
        }

        // Snapshot the keys because we delete while iterating. Working on keys
        // (not values) also means a malformed stored value cannot make this throw.
        for (const id of Array.from(nodesMap.keys())) {
          if (!seen.has(id)) {
            nodesMap.delete(id);
          }
        }
      });
    },
    [ydoc],
  );

  // Applies React Flow changes to nodesMap. The observer below then copies the
  // updated map into the `nodes` state.
  const onNodesChanges: OnNodesChange = useCallback(
    changes => {
      ydoc.transact(() => {
        const nodesMap = nodesMapRef.current;
        const edgesMap = edgesMapRef.current;
        const nextNodes = applyNodeChanges(
          changes,
          Array.from(nodesMap.values()),
        );

        // Index the updated nodes once instead of a linear search per change.
        const nextById = new Map<string, Node>();
        for (const node of nextNodes) {
          nextById.set(node.id, node);
        }

        for (const change of changes) {
          switch (change.type) {
            case 'add':
            case 'reset':
              nodesMap.set(change.item.id, change.item);
              break;
            case 'remove': {
              const deletedNode = nodesMap.get(change.id);
              // If the node is already gone (e.g. removed elsewhere), there is
              // nothing to delete. Previously this case fell through to the
              // update branch and stored `undefined` in the map.
              if (deletedNode !== undefined) {
                const connectedEdges = getConnectedEdges(
                  [deletedNode],
                  Array.from(edgesMap.values()),
                );
                nodesMap.delete(change.id);
                for (const edge of connectedEdges) {
                  edgesMap.delete(edge.id);
                }
              }
              break;
            }
            default: {
              // Remaining change kinds update an existing node in place. Skip
              // ids that are not in the map so `undefined` is never written.
              const updated = nextById.get(change.id);
              if (updated !== undefined) {
                nodesMap.set(change.id, updated);
              }
            }
          }
        }
      });
    },
    [ydoc],
  );

  // Copy nodesMap into the `nodes` state whenever the map changes.
  useEffect(() => {
    const nodesMap = nodesMapRef.current;

    const syncFromMap = () => {
      setNodes(Array.from(nodesMap.values()));
    };

    // Catch up on any writes made between the first render and subscribing.
    syncFromMap();
    nodesMap.observe(syncFromMap);

    return () => {
      nodesMap.unobserve(syncFromMap);
    };
  }, []);

  return [nodes, setNodesSynced, onNodesChanges];
}

export default useNodesStateSynced;
