import { ApolloCache } from '@apollo/client'

import { useChartId } from '../providers/use-chart-uuid'
import { useTreeState } from '../providers/use-tree-state'

import { nodes as nodesQuery } from 'apollo/query'
import { getChartNormalizedId, getNodeNormalizedId } from 'apollo/helpers'
import { nodeDataFragment, chartFragment } from 'apollo/fragments'
import {
  NodesQuery,
  NodesQueryVariables,
  NodeDataFragmentFragment,
  ChartFragmentFragment,
} from 'apollo/generated/graphql'

// Direction in which ancestor counts are adjusted by countPayload.
type Action = 'increase' | 'decrease'
// Partial snapshot of a node as it was before a mutation (may be null/undefined).
type PrevData = Partial<NodeDataFragmentFragment> | null

/**
 * `Array.prototype.filter` predicate that keeps only the first occurrence of each element.
 *
 * Elements are compared via `indexOf`, i.e. strict equality: primitives by value, objects by
 * reference. Two distinct objects carrying the same `uuid` are NOT treated as duplicates — use a
 * key-based filter for that.
 */
const uniqueFilter = <T>(value: T, index: number, array: readonly T[]): boolean => array.indexOf(value) === index

/**
 * Returns a copy of `node` with its employee/department totals shifted by the weight of
 * `mutationData`.
 *
 * A Person contributes itself plus its `employeeCount` to the employee total; a Department
 * contributes itself plus its `departmentCount` to the department total. For any other kind only
 * the carried counts are used. The weight is added for `'increase'` and subtracted for `'decrease'`.
 *
 * @param node - ancestor node whose counts are adjusted
 * @param mutationData - node being added to / removed from the ancestor's subtree
 * @param action - whether to add or subtract the weight
 */
const countPayload = (node: NodeDataFragmentFragment, mutationData: NodeDataFragmentFragment, action: Action) => {
  const sign = action === 'increase' ? 1 : -1

  const employeeDelta =
    mutationData.__typename === 'Person' ? mutationData.employeeCount + 1 : mutationData.employeeCount
  const departmentDelta =
    mutationData.__typename === 'Department' ? mutationData.departmentCount + 1 : mutationData.departmentCount

  return {
    ...node,
    employeeCount: node.employeeCount + sign * employeeDelta,
    departmentCount: node.departmentCount + sign * departmentDelta,
  }
}

/**
 * Returns a copy of `chart` with its aggregate counts adjusted for a single node mutation.
 *
 * Assigned totals (`employeeCount` / `departmentCount`) change only when the node enters or leaves
 * the assigned tree. Unassigned totals change only when it enters or leaves the unassigned pool.
 * Assigned totals use the node's full weight: itself (+1 for its own kind) plus the counts it
 * carries. Unassigned totals use only the weight of the node's own kind; they are clamped at 0.
 *
 * @param chart - cached chart fragment to adjust
 * @param mutationData - node state after the mutation (`deleted` / `unassigned` flags)
 * @param prevData - node state before the mutation; when absent and the node is not deleted, the
 *   node is treated as newly created
 */
const chartCountPayload = (
  chart: ChartFragmentFragment,
  mutationData: NodeDataFragmentFragment,
  prevData?: PrevData
) => {
  const isDeleted = Boolean(mutationData.deleted)
  const isUnassigned = !isDeleted && Boolean(mutationData.unassigned)
  const wasUnassigned = Boolean(prevData?.unassigned)
  // A missing snapshot means "newly created" — but a deleted node was never just created, so a
  // delete without prevData is treated like a delete of an assigned node.
  const isNodeNew = !prevData && !isDeleted

  // Whether the node counts towards the assigned totals after / before the mutation.
  const isAssignedNow = !isDeleted && !isUnassigned
  const wasAssigned = !isNodeNew && !wasUnassigned

  // Mutually exclusive by construction, so a new unassigned node or a deleted node can never bump
  // the assigned totals.
  const shouldIncreaseCount = isAssignedNow && !wasAssigned
  const shouldDecreaseCount = !isAssignedNow && wasAssigned

  const shouldIncreaseUnassignedCount = !wasUnassigned && isUnassigned
  const shouldDecreaseUnassignedCount = wasUnassigned && !isUnassigned

  const term = shouldIncreaseCount ? 1 : shouldDecreaseCount ? -1 : 0
  const employeeWeight =
    mutationData.__typename === 'Person' ? mutationData.employeeCount + 1 : mutationData.employeeCount
  const departmentWeight =
    mutationData.__typename === 'Department' ? mutationData.departmentCount + 1 : mutationData.departmentCount

  const assignedCounts =
    term === 0
      ? {}
      : {
          employeeCount: chart.employeeCount + term * employeeWeight,
          departmentCount: chart.departmentCount + term * departmentWeight,
        }

  const unassignedTerm = shouldIncreaseUnassignedCount ? 1 : shouldDecreaseUnassignedCount ? -1 : 0

  const unassignedCounts =
    unassignedTerm === 0
      ? {}
      : {
          unassignedEmployeeCount: Math.max(
            chart.unassignedEmployeeCount +
              unassignedTerm * (mutationData.__typename === 'Person' ? mutationData.employeeCount + 1 : 0),
            0
          ),
          unassignedDepartmentCount: Math.max(
            chart.unassignedDepartmentCount +
              unassignedTerm * (mutationData.__typename === 'Department' ? mutationData.departmentCount + 1 : 0),
            0
          ),
        }

  return { ...chart, ...assignedCounts, ...unassignedCounts }
}

/**
 * Builds an iterable that walks the cached tree breadth-first, starting below `parentUuids`.
 *
 * For each visited uuid the cached `nodes(chartKey, { parentUuid, unassigned: false })` query is read.
 * Each of its items is yielded and queued. Uuids already expanded are skipped, so each parent's
 * children query is read at most once. Only data already in the cache is visited; nothing is fetched.
 *
 * @param cache - Apollo cache to read from
 * @param chartUuid - chart key used for the children queries
 * @param parentUuids - uuids whose descendants are walked
 * @param iterateParentUuids - when true, the cached fragments of `parentUuids` themselves are
 *   yielded once, before any descendant
 */
const treeIteratorFactory = (
  cache: ApolloCache<unknown>,
  chartUuid: string,
  parentUuids: string[] = [],
  iterateParentUuids = false
) => {
  return {
    [Symbol.iterator]: function* () {
      // Yield the starting nodes exactly once (previously they were re-yielded on every dequeue).
      if (iterateParentUuids) {
        for (const uuid of parentUuids) {
          const nodeId = getNodeNormalizedId(uuid, cache)
          if (!nodeId) {
            continue
          }
          const node = cache.readFragment<NodeDataFragmentFragment>({
            id: nodeId,
            fragment: nodeDataFragment,
            fragmentName: 'NodeDataFragment',
          })
          if (node) {
            yield node
          }
        }
      }

      const expanded = new Set<string>()
      const queue = [...parentUuids]

      // Index-based FIFO avoids the O(n) cost of Array#shift on every pop.
      for (let head = 0; head < queue.length; head++) {
        const parentUuid = queue[head]
        if (!parentUuid || expanded.has(parentUuid)) {
          continue
        }
        expanded.add(parentUuid)

        // Get nodes by parent uuid
        const nodes = cache.readQuery<NodesQuery, NodesQueryVariables>({
          query: nodesQuery,
          variables: { chartKey: chartUuid, filter: { parentUuid, unassigned: false } },
        })

        if (nodes) {
          for (const node of nodes.nodes.items) {
            yield node
            queue.push(node.uuid)
          }
        }
      }
    },
  }
}

/**
 * Rewrites `parentNodes` on every cached descendant of a moved node.
 *
 * For each descendant, the first `prevData.parentNodes.length` entries of its `parentNodes` are
 * dropped and `mutationData.parentNodes` is prepended. This assumes each descendant's ancestor
 * list starts with the moved node's old ancestors, in the same order — TODO confirm.
 *
 * @param cache - Apollo cache to patch
 * @param mutationData - moved node as returned by the mutation (carries the new `parentNodes`)
 * @param prevData - node state before the move; only `parentNodes.length` is used
 */
const updateParentNodes = (
  cache: ApolloCache<unknown>,
  mutationData: NodeDataFragmentFragment,
  prevData?: PrevData
) => {
  const newParentNodes = mutationData?.parentNodes || []
  const oldParentNodesToRemoveCount = prevData?.parentNodes?.length ?? 0

  const treeIterator = treeIteratorFactory(cache, mutationData.chartUuid, [mutationData.uuid])

  for (const node of treeIterator) {
    const nodeId = getNodeNormalizedId(node.uuid, cache)
    if (!nodeId) {
      continue
    }

    // Ancestors of `node` that sit below the moved node's old ancestor chain.
    const retainedParentNodes = node.parentNodes.slice(oldParentNodesToRemoveCount)

    // Deduplicate by uuid. The mutation result and the cache read produce distinct objects, so a
    // reference comparison would never detect a repeated ancestor.
    const seenUuids = new Set<string>()
    const parentNodes = [...newParentNodes, ...retainedParentNodes].filter(parent => {
      if (seenUuids.has(parent.uuid)) {
        return false
      }
      seenUuids.add(parent.uuid)
      return true
    })

    cache.writeFragment<NodeDataFragmentFragment>({
      id: nodeId,
      fragment: nodeDataFragment,
      fragmentName: 'NodeDataFragment',
      data: { ...node, parentNodes },
    })
  }
}

/**
 * Reads the cached chart fragment for `mutationData.chartUuid` and writes back the totals computed
 * by `chartCountPayload`. Does nothing when the chart fragment is not cached.
 *
 * @param cache - Apollo cache to patch
 * @param mutationData - node state after the mutation
 * @param prevData - node state before the mutation, forwarded to `chartCountPayload`
 */
const updateChartCount = (cache: ApolloCache<unknown>, mutationData: NodeDataFragmentFragment, prevData?: PrevData) => {
  const chartFragmentOptions = {
    id: getChartNormalizedId(mutationData.chartUuid),
    fragment: chartFragment,
    fragmentName: 'ChartFragment',
  }

  const cachedChart = cache.readFragment<ChartFragmentFragment>(chartFragmentOptions)
  if (!cachedChart) {
    return
  }

  cache.writeFragment<ChartFragmentFragment>({
    ...chartFragmentOptions,
    data: chartCountPayload(cachedChart, mutationData, prevData),
  })
}

type UpdateCountInnerParams = {
  cache: ApolloCache<unknown>
  mutationData: NodeDataFragmentFragment
  path: string[]
  action: Action
}
/**
 * Factory returning `updateCountInner`, which applies `countPayload` to each node listed in `path`,
 * in order.
 *
 * Uuids that do not resolve to a cache id, or whose fragment is not cached, are skipped silently.
 */
const useUpdateCountInner = () => {
  const updateCountInner = ({ cache, mutationData, path, action }: UpdateCountInnerParams) => {
    for (const ancestorUuid of path) {
      const ancestorId = getNodeNormalizedId(ancestorUuid, cache)
      if (!ancestorId) {
        continue
      }

      const ancestor = cache.readFragment<NodeDataFragmentFragment>({
        id: ancestorId,
        fragment: nodeDataFragment,
        fragmentName: 'NodeDataFragment',
      })
      if (!ancestor) {
        continue
      }

      cache.writeFragment<NodeDataFragmentFragment>({
        id: ancestorId,
        fragment: nodeDataFragment,
        fragmentName: 'NodeDataFragment',
        data: countPayload(ancestor, mutationData, action),
      })
    }
  }

  return { updateCountInner }
}

type UpdateCountParams = {
  cache: ApolloCache<unknown>
  mutationData: NodeDataFragmentFragment
  action: Action
  prevData?: PrevData
}
/**
 * Hook returning `updateCount`. It first updates the chart-level totals, then adjusts the counts
 * of every ancestor listed in `mutationData.parentNodes`, walking that list in reverse order.
 */
export const useUpdateCount = () => {
  const { updateCountInner } = useUpdateCountInner()

  const updateCount = ({ cache, mutationData, action, prevData }: UpdateCountParams) => {
    updateChartCount(cache, mutationData, prevData)

    const ancestorPath = [...mutationData.parentNodes].reverse().map(({ uuid }) => uuid)
    updateCountInner({ cache, mutationData, path: ancestorPath, action })
  }

  return { updateCount }
}

type DeleteOrUnassignNodeParams = {
  cache: ApolloCache<unknown>
  // Node(s) in their post-mutation state; their `deleted` / `unassigned` flags pick the final step.
  mutationNodes: readonly NodeDataFragmentFragment[]
  // Pre-mutation snapshot; forwarded to updateCount and used to detect whether the node was unassigned.
  prevData?: PrevData
  // Parent the node(s) hung from before the mutation; falsy values select the `parentUuid: null` list.
  prevParentUuid: string | null
}
/**
 * Hook returning `deleteOrUnassignNode`, which patches the Apollo cache after one or more nodes
 * have been deleted or unassigned.
 *
 * For the whole batch it first prunes `expandedNodes` in the tree state. Then, for each node, it:
 * - decreases ancestor and chart counts;
 * - removes the node from the list query it appeared in;
 * - resets the node's own children query to empty;
 * - finally either schedules removal of its cache entry (`deleted`) or adds it to the unassigned
 *   list (`unassigned`).
 */
export const useDeleteOrUnassignNode = () => {
  const { updateCount } = useUpdateCount()
  const { treeStateRef: chartStateRef, replaceExpandedNodes } = useTreeState()
  const chartUuid = useChartId()

  // NOTE(review): updateExpandedNodes / updateCacheAndCounts are consts declared further down; calling
  // them here is safe only because deleteOrUnassignNode runs after the hook body has finished.
  const deleteOrUnassignNode = ({ cache, mutationNodes, prevData, prevParentUuid }: DeleteOrUnassignNodeParams) => {
    updateExpandedNodes({ mutationNodes, prevParentUuid, cache })
    mutationNodes.forEach(deletedNode => {
      updateCacheAndCounts({ cache, node: deletedNode, prevParentUuid, prevData })
    })
  }

  type UpdateExpandedNodesParams = Omit<DeleteOrUnassignNodeParams, 'prevData'>
  // Removes the mutated nodes from the expanded set. The old parent is removed too when it is left
  // without children.
  const updateExpandedNodes = ({ mutationNodes, prevParentUuid, cache }: UpdateExpandedNodesParams) => {
    // Update expanded nodes
    const unexpandUuidList: (string | null)[] = mutationNodes.map(({ uuid }) => uuid)
    const parentNode = cache.readQuery<NodesQuery, NodesQueryVariables>({
      query: nodesQuery,
      variables: { chartKey: chartUuid, filter: { parentUuid: prevParentUuid, unassigned: false } },
    })
    const parentNodeItems = parentNode?.nodes.items || []
    // This runs before updateCacheAndCounts, so the cached list still contains the mutated node(s);
    // length 1 presumably means the mutated node was the parent's only child.
    // NOTE(review): when several mutationNodes share this parent the check never fires — confirm.
    if (parentNodeItems.length === 1 && prevParentUuid) {
      // Add parent to list as it has no subordinates now
      unexpandUuidList.push(prevParentUuid)
    }
    const newExpandedNodes = chartStateRef.current.expandedNodes.filter(node => !unexpandUuidList.includes(node.uuid))
    replaceExpandedNodes(newExpandedNodes)
  }

  type UpdateCacheAndCountsParams = Omit<DeleteOrUnassignNodeParams, 'mutationNodes'> & {
    node: NodeDataFragmentFragment
  }
  // Per-node cache update. Order matters: e.g. the old parent's list is rewritten before the
  // subtree walk below reads the cached children queries.
  const updateCacheAndCounts = ({ cache, node, prevData, prevParentUuid }: UpdateCacheAndCountsParams) => {
    const prevParentUuidOrNull = prevParentUuid || null
    const wasUnassigned = prevData?.unassigned

    // Update counts
    // If the old parent is cached, the ancestor chain is rebuilt from it (its parentNodes + itself),
    // presumably because node.parentNodes no longer describes the pre-mutation position.
    const oldParentId = getNodeNormalizedId(prevParentUuidOrNull, cache)
    if (oldParentId) {
      const oldParentNode = cache.readFragment<NodeDataFragmentFragment>({
        id: oldParentId,
        fragment: nodeDataFragment,
        fragmentName: 'NodeDataFragment',
      })

      if (oldParentNode) {
        const parentNodes = [...oldParentNode.parentNodes]
        parentNodes.push(oldParentNode)
        const mutationData = { ...node, parentNodes }
        updateCount({ cache, mutationData, action: 'decrease', prevData })
      }
    } else {
      updateCount({ cache, mutationData: node, action: 'decrease', prevData })
    }

    if (wasUnassigned) {
      // Remove node from nodes(chartUuid, unassigned: true) query
      const oldParentQuery = cache.readQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: { chartKey: node.chartUuid, filter: { unassigned: true } },
      })

      if (oldParentQuery?.nodes.items) {
        const oldNodes = oldParentQuery.nodes.items
        const newNodes = oldNodes.filter(n => n.uuid !== node.uuid)

        // Only write when the node was actually listed.
        if (newNodes.length !== oldNodes.length) {
          cache.writeQuery<NodesQuery, NodesQueryVariables>({
            query: nodesQuery,
            variables: { chartKey: node.chartUuid, filter: { unassigned: true } },
            data: { __typename: 'Query', nodes: { ...oldParentQuery.nodes, items: newNodes } },
          })
        }
      }
    } else {
      // Remove node from nodes(chartUuid, prevParentUuid) query
      // NOTE: this local chartUuid (from the node) shadows the hook-level chartUuid.
      const chartUuid = node.chartUuid
      const oldParentQuery = cache.readQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: { chartKey: chartUuid, filter: { parentUuid: prevParentUuidOrNull, unassigned: false } },
      })

      if (oldParentQuery?.nodes.items) {
        const oldNodes = oldParentQuery.nodes.items
        const newNodes = oldNodes.filter(n => n.uuid !== node.uuid)

        // Only write when the node was actually listed.
        if (newNodes.length !== oldNodes.length) {
          cache.writeQuery<NodesQuery, NodesQueryVariables>({
            query: nodesQuery,
            variables: { chartKey: chartUuid, filter: { parentUuid: prevParentUuidOrNull, unassigned: false } },
            data: { __typename: 'Query', nodes: { ...oldParentQuery.nodes, items: newNodes } },
          })
        }

        // Remove all children queries
        // NOTE(review): the walk starts from node.parentNodes (ancestors) with iterateParentUuids=true.
        // It therefore empties the children query of every yielded ancestor and cached descendant.
        // If prevParentUuid is in that list, this includes the query just rewritten above.
        // The removed node's own subtree is no longer reachable, because the node was just removed
        // from its parent's list. Confirm whether only that subtree was meant to be cleared.
        const treeIterator = treeIteratorFactory(
          cache,
          chartUuid,
          node.parentNodes.map(n => n.uuid),
          true
        )

        // NOTE: `node` inside this callback shadows the outer mutated node.
        Array.from(treeIterator).forEach(node => {
          const parentUuid = node.uuid
          const nodes = cache.readQuery<NodesQuery, NodesQueryVariables>({
            query: nodesQuery,
            variables: { chartKey: chartUuid, filter: { parentUuid, unassigned: false } },
          })
          if (nodes?.nodes) {
            cache.writeQuery<NodesQuery, NodesQueryVariables>({
              query: nodesQuery,
              variables: { chartKey: chartUuid, filter: { parentUuid, unassigned: false } },
              data: { ...nodes, nodes: { ...nodes.nodes, items: [] } },
            })
          }
        })
      }
    }

    // Reset the node's own children query, nodes(chartUuid, parentUuid: node.uuid), to an empty list
    cache.writeQuery<NodesQuery, NodesQueryVariables>({
      query: nodesQuery,
      variables: { chartKey: node.chartUuid, filter: { parentUuid: node.uuid, unassigned: false } },
      data: { __typename: 'Query', nodes: { __typename: 'NodeCollection', items: [] } },
    })

    // Delete node from cache
    // NOTE(review): this reaches into the cache's non-public `data` store and deletes by raw node.uuid,
    // while other lookups go through getNodeNormalizedId. If normalized ids differ from raw uuids,
    // this is a no-op. `cache.evict({ id })` followed by `cache.gc()` is the public API for removal.
    if (node.deleted) {
      setTimeout(() => {
        // Delete in the next tick
        ;(cache as any).data.delete(node.uuid)
      })
      // Move node to nodes(unassigned: true) query
    } else if (node.unassigned) {
      const result = cache.readQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: { chartKey: node.chartUuid, filter: { unassigned: true } },
      })

      // Only updates an unassigned list that is already cached.
      if (result) {
        const nodes = [...result.nodes.items, node].filter(
          (
            node,
            index,
            self // Filter out duplicates
          ) => index === self.findIndex(n => n.uuid === node.uuid)
        )

        cache.writeQuery<NodesQuery, NodesQueryVariables>({
          query: nodesQuery,
          variables: { chartKey: node.chartUuid, filter: { unassigned: true } },
          data: { __typename: 'Query', nodes: { ...result.nodes, items: nodes } },
        })
      }
    }
  }

  return { deleteOrUnassignNode }
}

type UpdateTreeParams = {
  cache: ApolloCache<unknown>
  mutationData: NodeDataFragmentFragment
  prevParentUuid?: string | null | undefined
  nextParentUuid?: string | null | undefined
  prevData?: PrevData
}
/**
 * Hook returning `updateTree`, which re-homes a node in the cached tree after a mutation.
 *
 * `updateTree` performs these steps in order:
 * 1. removes the node from the unassigned list (if it was unassigned) and from its old parent's
 *    children query;
 * 2. decreases counts along the old ancestor chain;
 * 3. appends the node to the new parent's children query;
 * 4. rewrites `parentNodes` on its cached descendants;
 * 5. increases counts along the chain in `mutationData.parentNodes`.
 */
export const useUpdateTree = () => {
  const { updateCount } = useUpdateCount()

  /**
   * @param cache - Apollo cache to patch
   * @param mutationData - node as returned by the mutation (post-move state)
   * @param prevParentUuid - parent before the move; falsy values are normalized to `null`
   * @param nextParentUuid - parent after the move; falsy values are normalized to `null`
   * @param prevData - node state before the mutation, forwarded to the count and parent updates
   */
  const updateTree = ({ cache, mutationData, prevParentUuid, nextParentUuid, prevData }: UpdateTreeParams) => {
    const prevParentUuidOrNull = prevParentUuid || null
    const nextParentUuidOrNull = nextParentUuid || null
    const wasUnassigned = prevData?.unassigned

    // Remove from unassigned
    if (wasUnassigned) {
      const result = cache.readQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: { chartKey: mutationData.chartUuid, filter: { unassigned: true } },
      })

      if (result) {
        // Filter out node
        const nodes = result.nodes.items.filter(node => node.uuid !== mutationData.uuid)
        cache.writeQuery<NodesQuery, NodesQueryVariables>({
          query: nodesQuery,
          variables: { chartKey: mutationData.chartUuid, filter: { unassigned: true } },
          data: { __typename: 'Query', nodes: { ...result.nodes, items: nodes } },
        })
      }
    }

    // Remove node from the old branch
    const oldParentQuery = cache.readQuery<NodesQuery, NodesQueryVariables>({
      query: nodesQuery,
      variables: {
        chartKey: mutationData.chartUuid,
        filter: { parentUuid: prevParentUuidOrNull, unassigned: false },
      },
    })

    if (oldParentQuery?.nodes.items) {
      const nodes = oldParentQuery.nodes.items.filter(node => node.uuid !== mutationData.uuid)

      cache.writeQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: {
          chartKey: mutationData.chartUuid,
          filter: { parentUuid: prevParentUuidOrNull, unassigned: false },
        },
        data: { __typename: 'Query', nodes: { ...oldParentQuery.nodes, items: nodes } },
      })
    }

    // Decrease counts on the old branch
    // NOTE(review): updateCount also adjusts the chart-level totals. When an old parent is cached,
    // that chart adjustment runs here AND in the final updateCount call, with the same
    // deleted/unassigned inputs — confirm this is intended.
    const oldParentId = getNodeNormalizedId(prevParentUuidOrNull, cache)
    if (oldParentId) {
      const oldParentNode = cache.readFragment<NodeDataFragmentFragment>({
        id: oldParentId,
        fragment: nodeDataFragment,
        fragmentName: 'NodeDataFragment',
      })

      if (oldParentNode) {
        // The old ancestor chain is the old parent's ancestors plus the old parent itself.
        const parentNodes = [...oldParentNode.parentNodes, oldParentNode]
        const data = { ...mutationData, parentNodes }
        updateCount({ cache, mutationData: data, action: 'decrease', prevData })
      }
    }

    // Add node to the new branch
    const newParentQuery = cache.readQuery<NodesQuery, NodesQueryVariables>({
      query: nodesQuery,
      variables: {
        chartKey: mutationData.chartUuid,
        filter: { parentUuid: nextParentUuidOrNull, unassigned: false },
      },
    })

    if (newParentQuery?.nodes.items) {
      // Deduplicate by uuid, keeping the first occurrence. mutationData is a fresh object, so a
      // reference comparison could never match an entry that is already listed.
      const newNodes = [...newParentQuery.nodes.items, mutationData].filter(
        (node, index, self) => index === self.findIndex(n => n.uuid === node.uuid)
      )

      cache.writeQuery<NodesQuery, NodesQueryVariables>({
        query: nodesQuery,
        variables: {
          chartKey: mutationData.chartUuid,
          filter: { parentUuid: nextParentUuidOrNull, unassigned: false },
        },
        data: { __typename: 'Query', nodes: { ...newParentQuery.nodes, items: newNodes } },
      })
    }

    updateParentNodes(cache, mutationData, prevData)
    // Increase counts on the new branch
    updateCount({ cache, mutationData, action: 'increase', prevData })
  }

  return { updateTree }
}
