import Immutable, { type Map, type OrderedSet, type Set } from 'immutable'

import { getNodeId } from './id-generator'
import {
  type BranchId,
  type JourneyNodeRecord,
  type YesNoNodeRecord,
  type RandomNodeRecord,
  type NodeError,
  type NodeType,
  YesNoNodeFactory,
  RandomNodeFactory,
  SplitBranchFactory,
  MessageNodeFactory,
  TimerNodeFactory,
  FinalNodeFactory,
  type TimerNodeRecord,
  type MessageNodeRecord,
  JourneyNodes,
  type FinalNodeRecord,
} from './journey.records'

import {
  type MessageConfigRecord,
  MessageConfigFactory,
} from 'com.batch/message/models/message.records'

/*
  Looks up a node by id.
  Throws when the id is not a key of `nodes`, so callers never get an undefined node back.
*/
export const getNodeById = ({
  nodes,
  nodeId,
}: {
  nodes: Map<string, JourneyNodeRecord>
  nodeId: string
}): JourneyNodeRecord => {
  const found = nodes.get(nodeId)
  if (found) return found
  throw new Error(`Node ${nodeId} was not found`)
}

/*
  Type guards on the node discriminant (`type`).
  Each guard narrows `node` to the matching record type when it returns true.
*/
export const isYesNo = (node: JourneyNodeRecord): node is YesNoNodeRecord => node.type === 'YESNO'
export const isRandom = (node: JourneyNodeRecord): node is RandomNodeRecord =>
  node.type === 'RANDOM'
export const isMessage = (node: JourneyNodeRecord): node is MessageNodeRecord =>
  node.type === 'MESSAGE'
export const isTimer = (node: JourneyNodeRecord): node is TimerNodeRecord => node.type === 'TIMER'
// declared as a type predicate like its siblings, so callers get narrowing too
export const isFinal = (node: JourneyNodeRecord): node is FinalNodeRecord => node.type === 'FINAL'

/*
  Assertion helpers: each one returns normally only when `node` is defined and has the
  expected type (narrowing `node` for the caller), and throws an Error otherwise.
*/
export function assertIsYesNode(node: Maybe<JourneyNodeRecord>): asserts node is YesNoNodeRecord {
  if (node?.type === 'YESNO') return
  throw new Error('Node is not a YESNO node')
}
export function assertIsRandomNode(
  node: Maybe<JourneyNodeRecord>
): asserts node is RandomNodeRecord {
  if (node?.type === 'RANDOM') return
  throw new Error('Node is not a RANDOM node')
}
export function assertIsMessageNode(
  node: null | JourneyNodeRecord | undefined
): asserts node is MessageNodeRecord {
  if (node?.type === 'MESSAGE') return
  throw new Error('Node is not a MESSAGE node')
}
export function assertIsTimerNode(node: Maybe<JourneyNodeRecord>): asserts node is TimerNodeRecord {
  if (node?.type === 'TIMER') return
  throw new Error('Node is not a TIMER node')
}
export function assertIsFinalNode(node: Maybe<JourneyNodeRecord>): asserts node is FinalNodeRecord {
  if (node?.type === 'FINAL') return
  throw new Error('Node is not a FINAL node')
}
/*
  Lists every branch (outgoing link) of every node that points to `nodeId`.
  When no node links to it, the node is considered to be the root and `[{ type: 'ROOT' }]`
  is returned, so the result is never empty.
*/
export const getParentsBranchId = ({
  nodes,
  nodeId,
}: {
  nodes: Map<string, JourneyNodeRecord>
  nodeId: string
}): Array<BranchId> => {
  const parents: Array<BranchId> = []
  nodes.forEach(candidate => {
    if (candidate.type === 'MESSAGE') {
      if (candidate.nextNodeId === nodeId)
        parents.push({ type: 'MESSAGE', stepMessageNodeId: candidate.id })
    } else if (candidate.type === 'TIMER') {
      if (candidate.nextNodeId === nodeId)
        parents.push({ type: 'TIMER-NEXT', timerNodeId: candidate.id })
      candidate.onEvents.forEach((onEvent, triggerIndex) => {
        if (onEvent.nextNodeId === nodeId)
          parents.push({ type: 'TIMER-EVENT', timerNodeId: candidate.id, triggerIndex })
      })
    } else if (candidate.type === 'RANDOM') {
      candidate.splits.forEach((split, splitIndex) => {
        if (split.nextNodeId === nodeId)
          parents.push({ type: 'RANDOM', randomNodeId: candidate.id, splitIndex })
      })
    } else if (candidate.type === 'YESNO') {
      if (candidate.yesNodeId === nodeId)
        parents.push({ type: 'YESNO', yesNoNodeId: candidate.id, branch: 'yes' })
      if (candidate.noNodeId === nodeId)
        parents.push({ type: 'YESNO', yesNoNodeId: candidate.id, branch: 'no' })
    }
    // FINAL nodes have no outgoing link
  })
  if (parents.length > 0) return parents
  return [{ type: 'ROOT' }]
}

/*
  Counts the nodes of the given `type` in either a node map or an ordered set of nodes.
*/
export const countNodeType = (
  nodes: Map<string, JourneyNodeRecord> | OrderedSet<JourneyNodeRecord>,
  type: NodeType
): number => {
  const hasRequestedType = (node: JourneyNodeRecord): boolean => node.type === type
  // Both branches do the same thing; presumably the split is kept so each `filter` call is made on
  // a single (narrowed) collection type rather than on the union.
  if (Immutable.OrderedSet.isOrderedSet(nodes)) return nodes.filter(hasRequestedType).size
  return nodes.filter(hasRequestedType).size
}

/*
  Returns the ids of every outgoing branch of `node`:
    - FINAL: none
    - MESSAGE: its single next link
    - TIMER: its next link, followed by one branch per onEvents entry
    - YESNO: the yes branch then the no branch
    - RANDOM: one branch per split
*/
export const getAllNodeBranchIds = (node: JourneyNodeRecord): Array<BranchId> => {
  const ownerId = node.id
  switch (node.type) {
    case 'MESSAGE':
      return [{ type: 'MESSAGE', stepMessageNodeId: ownerId }]
    case 'TIMER': {
      const nextBranch: BranchId = { type: 'TIMER-NEXT', timerNodeId: ownerId }
      const eventBranches = node.onEvents
        .toArray()
        .map(
          (_onEvent, triggerIndex): BranchId => ({
            type: 'TIMER-EVENT',
            timerNodeId: ownerId,
            triggerIndex,
          })
        )
      return [nextBranch, ...eventBranches]
    }
    case 'YESNO': {
      const yesBranch: BranchId = { type: 'YESNO', yesNoNodeId: ownerId, branch: 'yes' }
      const noBranch: BranchId = { type: 'YESNO', yesNoNodeId: ownerId, branch: 'no' }
      return [yesBranch, noBranch]
    }
    case 'RANDOM':
      return node.splits
        .toArray()
        .map(
          (_split, splitIndex): BranchId => ({ type: 'RANDOM', randomNodeId: ownerId, splitIndex })
        )
    case 'FINAL':
      return []
  }
}

/*
  Resolves the node a branch id points to.
  Throws when the owner node is missing / has the wrong type, or when the link target is missing.
  Returns null for branch ids that do not designate a real link (REJOIN; ROOT is handled by callers).
*/
const resolveBranchTarget = (
  nodesMap: Map<string, JourneyNodeRecord>,
  branchId: BranchId
): JourneyNodeRecord | null => {
  switch (branchId.type) {
    case 'MESSAGE': {
      const messageNode = nodesMap.get(branchId.stepMessageNodeId)
      if (!messageNode) throw new Error(`Message node ${branchId.stepMessageNodeId} was not found`)
      if (messageNode.type !== 'MESSAGE')
        throw new Error(`Node ${branchId.stepMessageNodeId} is not a message node`)
      const next = nodesMap.get(messageNode.nextNodeId)
      if (!next) throw new Error(`Next node ${messageNode.nextNodeId} was not found`)
      return next
    }
    case 'TIMER-NEXT': {
      const timerNode = nodesMap.get(branchId.timerNodeId)
      if (!timerNode) throw new Error(`Timer node ${branchId.timerNodeId} was not found`)
      if (timerNode.type !== 'TIMER')
        throw new Error(`Node ${branchId.timerNodeId} is not a timer node`)
      const next = nodesMap.get(timerNode.nextNodeId)
      if (!next) throw new Error(`Next node ${timerNode.nextNodeId} was not found`)
      return next
    }
    case 'TIMER-EVENT': {
      const timerNode = nodesMap.get(branchId.timerNodeId)
      if (!timerNode) throw new Error(`Timer node ${branchId.timerNodeId} was not found`)
      if (timerNode.type !== 'TIMER')
        throw new Error(`Node ${branchId.timerNodeId} is not a timer node`)
      const nextBranchId = timerNode.onEvents.get(branchId.triggerIndex)?.nextNodeId
      if (!nextBranchId)
        throw new Error(`Next onEvents branch for ${branchId.triggerIndex} was not found`)
      const next = nodesMap.get(nextBranchId)
      if (!next) throw new Error(`Next node ${nextBranchId} was not found`)
      return next
    }
    case 'YESNO': {
      const yesNoNode = nodesMap.get(branchId.yesNoNodeId)
      if (!yesNoNode) throw new Error(`YesNo node ${branchId.yesNoNodeId} was not found`)
      if (yesNoNode.type !== 'YESNO')
        throw new Error(`Node ${branchId.yesNoNodeId} is not a yesNo node`)
      const next = nodesMap.get(
        branchId.branch === 'yes' ? yesNoNode.yesNodeId : yesNoNode.noNodeId
      )
      if (!next)
        throw new Error(
          `${branchId.branch.toUpperCase()} branch for node ${yesNoNode.id} was not found`
        )
      return next
    }
    case 'RANDOM': {
      const randomNode = nodesMap.get(branchId.randomNodeId)
      if (!randomNode) throw new Error(`Random node ${branchId.randomNodeId} was not found`)
      if (randomNode.type !== 'RANDOM')
        throw new Error(`Node ${branchId.randomNodeId} is not a random node`)
      const nextNodeId = randomNode.splits.get(branchId.splitIndex, SplitBranchFactory()).nextNodeId
      const next = nodesMap.get(nextNodeId ? nextNodeId : 'MISSING LINK')
      if (!next)
        throw new Error(
          `splits[${branchId.splitIndex}] branch for node ${randomNode.id} was not found`
        )
      return next
    }
    default:
      return null
  }
}

/*
  Walks through a branch and returns all of its descendants, in depth-first pre-order: the node the
  branch points to comes first, then the descendants of each of its own branches in turn.
  For a ROOT branch id, every node of the map is returned (and `result` is ignored).
  Nodes found are merged after the optional `result` set.
  Throws when a link on the walked path points to a missing node.

  Each node is expanded only once. In a split/rejoin journey the part after the rejoin is shared by
  every branch; walking it again for each branch made the cost grow exponentially with the number of
  nested splits. Skipping an already-visited node does not change the result, because the whole
  subtree of that node was already listed the first time it was reached.
  NOTE(review): visited nodes are tracked by id, assuming ids are unique map keys.
*/
export const getAllDescendantsForBranch = ({
  nodesMap,
  branchId,
  result = Immutable.OrderedSet(),
}: {
  nodesMap: Map<string, JourneyNodeRecord>
  branchId: BranchId
  result?: OrderedSet<JourneyNodeRecord>
}): OrderedSet<JourneyNodeRecord> => {
  if (branchId.type === 'ROOT') return Immutable.OrderedSet(nodesMap.values())

  const ordered: Array<JourneyNodeRecord> = []
  const visitedIds: Record<string, true> = Object.create(null)

  const walk = (currentBranchId: BranchId): void => {
    const target = resolveBranchTarget(nodesMap, currentBranchId)
    if (!target || visitedIds[target.id]) return
    visitedIds[target.id] = true
    ordered.push(target)
    if (target.type !== 'FINAL') getAllNodeBranchIds(target).forEach(walk)
  }

  walk(branchId)
  return ordered.length === 0 ? result : result.merge(ordered)
}

/*
  Returns the first node, in the depth-first order of the node's first branch, that is a descendant
  of every branch of `node` (for a split node: its rejoin node).
  For a single-branch node this is simply its next node; for a node without any branch (FINAL)
  there is nothing to intersect and undefined is returned.
*/
export const findCommonNode = ({
  node,
  nodesMap,
}: {
  node: JourneyNodeRecord
  nodesMap: Map<string, JourneyNodeRecord>
}): JourneyNodeRecord | null | undefined => {
  const [firstBranchDescendants, ...otherBranchesDescendants] = getAllNodeBranchIds(node).map(
    branchId => getAllDescendantsForBranch({ nodesMap, branchId })
  )
  // no branch at all: reducing would otherwise start from undefined and crash on `.first()`
  if (!firstBranchDescendants) return undefined
  // intersect keeps the order of the accumulator, i.e. the depth-first order of the first branch
  return otherBranchesDescendants
    .reduce(
      (common, descendantsForBranch) => common.intersect(descendantsForBranch),
      firstBranchDescendants
    )
    .first()
}
/*
  Split nodes only (random / yesno).
  Returns the nodes that belong to one specific branch of the split: the branch descendants, minus
  the rejoin node and everything reachable from it (those are shared by all the branches).
  Throws for any other kind of branch id.

  Used to count message nodes per branch on the remove node modal.
*/
export const getBranchSplitNodes = ({
  nodesMap,
  branchId,
}: {
  nodesMap: Map<string, JourneyNodeRecord>
  branchId: BranchId
}): OrderedSet<JourneyNodeRecord> => {
  // A set difference is used instead of cutting the list at the first occurrence of the rejoin node:
  // getAllDescendantsForBranch walks depth-first, so the nodes of a nested split's later branches are
  // listed after the rejoin node and would otherwise be missed.
  const getNodesBeforeRejoin = (splitNode: JourneyNodeRecord): OrderedSet<JourneyNodeRecord> => {
    const orderedDescendants = getAllDescendantsForBranch({ nodesMap, branchId })
    const rejoinNode = findCommonNode({ nodesMap, node: splitNode })
    if (!rejoinNode) return orderedDescendants
    const sharedNodes = getAllNodeBranchIds(rejoinNode).reduce(
      (shared, rejoinBranchId) =>
        shared.union(getAllDescendantsForBranch({ nodesMap, branchId: rejoinBranchId })),
      Immutable.OrderedSet<JourneyNodeRecord>([rejoinNode])
    )
    return orderedDescendants.filter(node => !sharedNodes.has(node))
  }

  switch (branchId.type) {
    case 'ROOT':
    case 'MESSAGE':
    case 'TIMER-NEXT':
    case 'TIMER-EVENT':
      throw new Error('function shall be used for yesno and random branches only')
    case 'REJOIN':
      throw new Error('rejoin is not a real branch id')
    case 'YESNO': {
      const yesNoNode = nodesMap.get(branchId.yesNoNodeId)
      if (!yesNoNode) throw new Error(`YesNo node ${branchId.yesNoNodeId} was not found`)
      if (yesNoNode.type !== 'YESNO')
        throw new Error(`Node ${branchId.yesNoNodeId} is not a yesNo node`)
      return getNodesBeforeRejoin(yesNoNode)
    }
    case 'RANDOM': {
      const randomNode = nodesMap.get(branchId.randomNodeId)
      if (!randomNode) throw new Error(`Random node ${branchId.randomNodeId} was not found`)
      if (randomNode.type !== 'RANDOM')
        throw new Error(`Node ${branchId.randomNodeId} is not a random node`)
      return getNodesBeforeRejoin(randomNode)
    }
  }
}

/*
  Counts the MESSAGE nodes owned by one branch of a yesno / random split
  (see getBranchSplitNodes for which nodes are considered part of the branch).
*/
export const countMessageNodesForSplitBranch = (
  nodesMap: Map<string, JourneyNodeRecord>,
  branchId: BranchId
): number => {
  const branchNodes = getBranchSplitNodes({ nodesMap, branchId })
  return countNodeType(branchNodes, 'MESSAGE')
}

/*
  Recomputes the validation errors of one node. New errors are appended to the node's current ones
  and stored as a Set, so duplicates collapse:
    - TIMER: INVALID_TIMER when the timer is not valid (outside 'until' mode), or when a
      'before' / 'after' timer has no timerReference
    - YESNO: SPLIT_EMPTY_BRANCHES when yes and no point to the same node
    - RANDOM: SPLIT_EMPTY_BRANCHES when every split points to the same node,
      SPLIT_RANDOM_DISTRIBUTION when the split weights do not sum to 100
*/
const validateNode = (node: JourneyNodeRecord): JourneyNodeRecord => {
  const collected: Array<NodeError> = node.errors.toArray()

  if (node.type === 'TIMER') {
    const timerIsInvalid = node.mode !== 'until' && !node.timer.valid
    const referenceIsMissing =
      !node.timerReference && (node.mode === 'before' || node.mode === 'after')
    if (timerIsInvalid) collected.push('INVALID_TIMER')
    if (referenceIsMissing) collected.push('INVALID_TIMER')
  } else if (node.type === 'YESNO') {
    if (node.yesNodeId === node.noNodeId) collected.push('SPLIT_EMPTY_BRANCHES')
  } else if (node.type === 'RANDOM') {
    const distinctTargets = node.splits.map(s => s.nextNodeId).toSet().size
    if (distinctTargets === 1) collected.push('SPLIT_EMPTY_BRANCHES')
    const totalWeight = node.splits.reduce((sum, split) => sum + split.weight, 0)
    if (totalWeight !== 100) collected.push('SPLIT_RANDOM_DISTRIBUTION')
  }

  const nextErrors = Immutable.Set(collected)
  // one set() per record type, so each call is made on a single narrowed type
  switch (node.type) {
    case 'FINAL':
      return node.set('errors', nextErrors)
    case 'MESSAGE':
      return node.set('errors', nextErrors)
    case 'YESNO':
      return node.set('errors', nextErrors)
    case 'RANDOM':
      return node.set('errors', nextErrors)
    case 'TIMER':
      return node.set('errors', nextErrors)
    default:
      return node
  }
}
/*
  Drops the "chain" errors (timer / end / split errors) from an error set, keeping all the others.
*/
const filterError = (errors: Set<NodeError>) => {
  const chainErrors: ReadonlyArray<string> = [
    'INVALID_TIMER',
    'INVALID_END',
    'SPLIT_EMPTY_BRANCHES',
    'SPLIT_RANDOM_DISTRIBUTION',
  ]
  const isChainError = (error: NodeError): boolean => chainErrors.includes(error)
  return errors.filter(error => !isChainError(error))
}

/*
  Removes the chain errors listed in filterError from every node of the map, so that a fresh
  validation pass starts without stale ones.
*/
const removeChainError = (
  nodes: Map<string, JourneyNodeRecord>
): Map<string, JourneyNodeRecord> =>
  nodes.map(node => {
    // one set() per record type, so each call is made on a single narrowed type
    switch (node.type) {
      case 'FINAL':
        return node.set('errors', filterError(node.errors))
      case 'MESSAGE':
        return node.set('errors', filterError(node.errors))
      case 'YESNO':
        return node.set('errors', filterError(node.errors))
      case 'TIMER':
        return node.set('errors', filterError(node.errors))
      case 'RANDOM':
        return node.set('errors', filterError(node.errors))
      default:
        return node
    }
  })

/*
  Revalidates every node of the journey: chain errors are cleared first, then recomputed per node.
*/
export const validateTree = ({
  nodesMap,
}: {
  nodesMap: Map<string, JourneyNodeRecord>
}): Map<string, JourneyNodeRecord> => {
  const withoutChainErrors = removeChainError(nodesMap)
  return withoutChainErrors.map(validateNode)
}

/*
  Inserts a new node on the yes (or no) branch of a yesno node: the new node takes over the branch's
  previous target as its next node, and the branch is re-pointed to the new node.
*/
const insertNodeAfterYesNo = ({
  yesno,
  isYes,
  id,
  messageConfig,
  nodeType,
  nodes,
}: {
  yesno: YesNoNodeRecord
  isYes: boolean
  messageConfig: MessageConfigRecord | null | undefined
  id: string
  nodeType: NodeType
  nodes: Map<string, JourneyNodeRecord>
}): Map<string, JourneyNodeRecord> => {
  const branchKey = isYes ? 'yesNodeId' : 'noNodeId'
  const insertedNode = buildNewNodeForNodeType({
    id,
    nodeType,
    messageConfig,
    nextNodeId: yesno.get(branchKey),
  })
  const relinkedYesNo = yesno.set(branchKey, id)
  return nodes.set(relinkedYesNo.id, relinkedYesNo).set(id, insertedNode)
}

/*
  Inserts a new node on split `index` of a random node: the new node takes over the split's
  previous target as its next node, and the split is re-pointed to the new node.
  Throws when the random node has no split at `index`.
*/
const insertNodeAfterRandom = ({
  split,
  index,
  id,
  nodeType,
  nodes,
  messageConfig,
}: {
  split: RandomNodeRecord
  index: number
  messageConfig: MessageConfigRecord | null | undefined
  id: string
  nodeType: NodeType
  nodes: Map<string, JourneyNodeRecord>
}): Map<string, JourneyNodeRecord> => {
  const targetSplit = split.splits.get(index)
  // without this guard, the setIn below would create a bogus entry at a non-existent list index
  if (!targetSplit) throw new Error(`Split ${index} of node ${split.id} was not found`)
  const newNode = buildNewNodeForNodeType({
    id,
    nodeType,
    // a split without a target keeps the 'NOTFOUND' placeholder link
    nextNodeId: targetSplit.nextNodeId ?? 'NOTFOUND',
    messageConfig,
  })
  const updatedSplit = split.setIn(['splits', index, 'nextNodeId'], id)
  return nodes.set(updatedSplit.id, updatedSplit).set(id, newNode)
}
/*
  Inserts a new node right after a message node: the new node inherits the message's next node,
  and the message now points to the new node.
*/
const insertNodeAfterMessage = ({
  message,
  id,
  messageConfig,
  nodeType,
  nodes,
}: {
  message: MessageNodeRecord
  id: string
  nodeType: NodeType
  messageConfig: MessageConfigRecord | null | undefined
  nodes: Map<string, JourneyNodeRecord>
}): Map<string, JourneyNodeRecord> => {
  const insertedNode = buildNewNodeForNodeType({
    id,
    nodeType,
    messageConfig,
    nextNodeId: message.nextNodeId,
  })
  const relinkedMessage = message.set('nextNodeId', insertedNode.id)
  return nodes.set(relinkedMessage.id, relinkedMessage).set(id, insertedNode)
}

/*
  Inserts a new node on the "next" link of a timer node: the new node inherits the timer's next
  node, and the timer now points to the new node. onEvents links are left untouched.
*/
const insertNodeAfterTimerNext = ({
  timer,
  id,
  nodeType,
  messageConfig,
  nodes,
}: {
  timer: TimerNodeRecord
  nodeType: NodeType
  id: string
  messageConfig: MessageConfigRecord | null | undefined
  nodes: Map<string, JourneyNodeRecord>
}): Map<string, JourneyNodeRecord> => {
  const insertedNode = buildNewNodeForNodeType({
    id,
    nodeType,
    messageConfig,
    nextNodeId: timer.nextNodeId,
  })
  const relinkedTimer = timer.set('nextNodeId', insertedNode.id)
  return nodes.set(relinkedTimer.id, relinkedTimer).set(id, insertedNode)
}

/*
  Creates a fresh node of `nodeType` linked to `nextNodeId`:
    - MESSAGE: uses `messageConfig`, or a default message config when none is given
    - TIMER: default timer pointing to nextNodeId
    - YESNO: both branches point to nextNodeId, flagged with MISSING_TARGETING
    - RANDOM: two 50/50 splits, both pointing to nextNodeId
    - anything else (FINAL): a final node; nextNodeId is not used
*/
const buildNewNodeForNodeType = ({
  id,
  nodeType,
  nextNodeId,
  messageConfig,
}: {
  id: string
  nodeType: NodeType
  nextNodeId: string
  messageConfig: MessageConfigRecord | null | undefined
}): JourneyNodeRecord => {
  if (nodeType === 'MESSAGE') {
    const config = messageConfig ?? MessageConfigFactory()
    return MessageNodeFactory({ id, nextNodeId, messageConfig: config })
  }
  if (nodeType === 'TIMER') return TimerNodeFactory({ id, nextNodeId })
  if (nodeType === 'YESNO') {
    return YesNoNodeFactory({
      id,
      errors: Immutable.Set(['MISSING_TARGETING']),
      yesNodeId: nextNodeId,
      noNodeId: nextNodeId,
    })
  }
  if (nodeType === 'RANDOM') {
    const evenSplit = () => SplitBranchFactory({ nextNodeId, weight: 50 })
    return RandomNodeFactory({ id, splits: Immutable.List([evenSplit(), evenSplit()]) })
  }
  return FinalNodeFactory({ id })
}

/*
  Applies one insertion step on a single branch.
    - isUpdate = false: creates node `id` of `nodeType` on the branch (the branch's previous target
      becomes the new node's next node). For ROOT, the new node becomes the root.
    - isUpdate = true: only re-points the branch to the already created node `id`.
  Throws for REJOIN and TIMER-EVENT branches, for an update on ROOT, and when the branch owner
  is missing or has an unexpected type.
*/
const insertNodeInBranch = ({
  nodes,
  rootId,
  branchId,
  messageConfig,
  isUpdate,
  id,
  nodeType,
}: {
  nodes: Map<string, JourneyNodeRecord>
  rootId: string
  branchId: BranchId
  nodeType: NodeType
  isUpdate: boolean
  id: string
  messageConfig: MessageConfigRecord | null | undefined
}): {
  rootId: string
  nodes: Map<string, JourneyNodeRecord>
} => {
  // every branch but ROOT keeps the current root and only changes the node map
  const keepRoot = (updatedNodes: Map<string, JourneyNodeRecord>) => ({
    rootId,
    nodes: updatedNodes,
  })

  switch (branchId.type) {
    case 'ROOT': {
      if (isUpdate) throw new Error('Cannot insert after root twice')
      const newRoot = buildNewNodeForNodeType({
        id,
        nodeType,
        messageConfig: messageConfig ?? MessageConfigFactory(),
        nextNodeId: rootId,
      })
      return { rootId: id, nodes: nodes.set(id, newRoot) }
    }

    case 'YESNO': {
      const yesNoNode = getNodeById({ nodes, nodeId: branchId.yesNoNodeId })
      if (yesNoNode.type !== 'YESNO')
        throw new Error(`Node ${branchId.yesNoNodeId} is not a YesNo node`)
      const isYes = branchId.branch === 'yes'
      if (isUpdate)
        return keepRoot(
          nodes.set(yesNoNode.id, yesNoNode.set(isYes ? 'yesNodeId' : 'noNodeId', id))
        )
      return keepRoot(
        insertNodeAfterYesNo({ id, nodeType, nodes, yesno: yesNoNode, isYes, messageConfig })
      )
    }

    case 'REJOIN':
      throw new Error('random rejoin is not a real branch id')

    case 'RANDOM': {
      const randomNode = getNodeById({ nodes, nodeId: branchId.randomNodeId })
      if (randomNode.type !== 'RANDOM')
        throw new Error(`Node ${branchId.randomNodeId} is not a Split node`)
      if (isUpdate)
        return keepRoot(
          nodes.set(
            randomNode.id,
            randomNode.setIn(['splits', branchId.splitIndex, 'nextNodeId'], id)
          )
        )
      return keepRoot(
        insertNodeAfterRandom({
          id,
          index: branchId.splitIndex,
          nodes,
          nodeType,
          split: randomNode,
          messageConfig,
        })
      )
    }

    case 'MESSAGE': {
      const messageNode = getNodeById({ nodes, nodeId: branchId.stepMessageNodeId })
      if (messageNode.type !== 'MESSAGE')
        throw new Error(`Node ${branchId.stepMessageNodeId} is not a Message node`)
      if (isUpdate) return keepRoot(nodes.set(messageNode.id, messageNode.set('nextNodeId', id)))
      return keepRoot(
        insertNodeAfterMessage({ id, nodes, nodeType, message: messageNode, messageConfig })
      )
    }

    case 'TIMER-NEXT': {
      const timerNode = getNodeById({ nodes, nodeId: branchId.timerNodeId })
      if (timerNode.type !== 'TIMER') throw new Error(`Node ${branchId.timerNodeId} is not a Timer`)
      if (isUpdate) return keepRoot(nodes.set(timerNode.id, timerNode.set('nextNodeId', id)))
      return keepRoot(
        insertNodeAfterTimerNext({ id, nodes, timer: timerNode, nodeType, messageConfig })
      )
    }

    case 'TIMER-EVENT':
      throw new Error('Not implemented yet')
  }
}

/*
  Inserts node `id` of type `nodeType` into the journey.
    The node is created on the LAST branch id of `branchIds` (that branch's previous target becomes
    the new node's next node). Every other branch id, processed from last to first, is then
    re-pointed to the new node (presumably for insertions where several branches converge).
  Returns the updated nodes and root id. Throws when `branchIds` is empty.
*/
export const insertNodeAndUpdateTree = ({
  branchIds,
  nodeType,
  id,
  nodes,
  rootId,
  messageConfig,
}: {
  nodes: Map<string, JourneyNodeRecord>
  rootId: string
  id: string
  branchIds: Array<BranchId>
  nodeType: NodeType
  messageConfig: MessageConfigRecord | null | undefined
}): {
  nodes: Map<string, JourneyNodeRecord>
  rootId: string
} => {
  const insertBranchId = branchIds[branchIds.length - 1]
  if (!insertBranchId) throw new Error('insertNodeAndUpdateTree requires at least one branch id')

  const inserted = insertNodeInBranch({
    id,
    nodes,
    rootId,
    isUpdate: false,
    branchId: insertBranchId,
    messageConfig,
    nodeType,
  })

  // slice / reduceRight never mutate `branchIds`, which may be shared state (e.g. a redux action)
  return branchIds.slice(0, -1).reduceRight(
    (tree, branchId) =>
      insertNodeInBranch({
        id,
        nodes: tree.nodes,
        rootId: tree.rootId,
        isUpdate: true,
        branchId,
        messageConfig,
        nodeType,
      }),
    inserted
  )
}
/*
  Deletes a node from the tree.
    branchToKeep tells us which outgoing branch of the deleted node survives;
    if branchToKeep is null/undefined, we delete the node and the descendants it owns
    (up to the rejoin node for a split node, every descendant otherwise).

    Common use cases:
      - delete a message, keeping its next node
      - delete a timer, keeping its next node
      - delete a yesno, keeping its yes branch
      - delete a yesno, keeping its no branch
      - delete a yesno, trashing all branches (branchToKeep is undefined)

    Future use cases:
      - delete a timer, keeping one of its onEvents branches
      - delete a big chunk of a journey (timer or message, branchToKeep undefined)

    Every parent branch of the deleted node is re-pointed to the kept target (the root id is updated
    when the deleted node was the root). Returns the filtered nodes and the root id.
    Throws when branchToKeep does not belong to nodeToRemove, or when a split has no rejoin node
    while branchToKeep is undefined.
 */
export const removeNodeAndUpdateTree = ({
  nodeToRemove,
  nodes,
  rootId,
  branchToKeep,
}: {
  nodeToRemove: JourneyNodeRecord
  nodes: Map<string, JourneyNodeRecord>
  rootId: string
  branchToKeep: BranchId | null | undefined
}): {
  nodes: Map<string, JourneyNodeRecord>
  rootId: string
} => {
  const ownerBranchIds: Array<BranchId> = getParentsBranchId({ nodes, nodeId: nodeToRemove.id })
  // ids of the nodes that become unreachable once nodeToRemove is gone
  const nodeIdFromRemovedBranches: Array<string> = []
  const nothingShared = Immutable.OrderedSet<JourneyNodeRecord>()

  // The rejoin node of a split plus everything reachable from it: shared by all the split branches,
  // so none of it may be deleted together with a single branch.
  const getSharedFromRejoin = (
    rejoinNode: JourneyNodeRecord | null | undefined
  ): OrderedSet<JourneyNodeRecord> =>
    rejoinNode
      ? getAllNodeBranchIds(rejoinNode).reduce(
          (shared, rejoinBranchId) =>
            shared.union(getAllDescendantsForBranch({ nodesMap: nodes, branchId: rejoinBranchId })),
          Immutable.OrderedSet<JourneyNodeRecord>([rejoinNode])
        )
      : nothingShared

  // Marks for deletion every descendant of `branchId` that is not in `shared`.
  // A set difference is used instead of "stop at the first occurrence of the rejoin node": the
  // depth-first order of getAllDescendantsForBranch lists the nodes of a nested split's later
  // branches after the rejoin node, and those would otherwise survive as unreachable nodes.
  const removeBranchNodesExcept = (
    branchId: BranchId,
    shared: OrderedSet<JourneyNodeRecord>
  ): void => {
    getAllDescendantsForBranch({ nodesMap: nodes, branchId }).forEach(orphanedNode => {
      if (!shared.has(orphanedNode)) nodeIdFromRemovedBranches.push(orphanedNode.id)
    })
  }

  // let's retrieve the node that replaces the removed node
  let newNextNodeId = ''
  switch (branchToKeep?.type) {
    case undefined: {
      if (nodeToRemove.type === 'YESNO' || nodeToRemove.type === 'RANDOM') {
        // we trash all branches up to the common descendant
        const rejoinNode = findCommonNode({ nodesMap: nodes, node: nodeToRemove })
        if (!rejoinNode) throw new Error(`No common node found for ${nodeToRemove.type}`)
        const shared = getSharedFromRejoin(rejoinNode)
        getAllNodeBranchIds(nodeToRemove).forEach(dest => {
          removeBranchNodesExcept(dest, shared)
        })
        newNextNodeId = rejoinNode.id
      } else {
        /*
          case not used for now: remove a non branching node and keep nothing
          we trash all descendants and add a new final node
        */
        const newFinalNode = JourneyNodes.Final({ id: getNodeId('FINAL') })
        nodes = nodes.set(newFinalNode.id, newFinalNode)
        newNextNodeId = newFinalNode.id
        getAllNodeBranchIds(nodeToRemove).forEach(dest => {
          removeBranchNodesExcept(dest, nothingShared)
        })
      }
      break
    }
    case 'ROOT': {
      throw new Error('ROOT is not a valid keep branchId')
    }
    case 'MESSAGE': {
      if (nodeToRemove.type !== 'MESSAGE' || nodeToRemove.id !== branchToKeep.stepMessageNodeId) {
        throw new Error('Invalid keep branchId MESSAGE ' + branchToKeep.stepMessageNodeId)
      }
      newNextNodeId = nodeToRemove.nextNodeId
      break
    }
    case 'TIMER-NEXT': {
      if (nodeToRemove.type !== 'TIMER' || nodeToRemove.id !== branchToKeep.timerNodeId) {
        throw new Error('Invalid keep branchId TIMER-NEXT ' + branchToKeep.timerNodeId)
      }
      newNextNodeId = nodeToRemove.nextNodeId
      /*
        we remove a timer and keep its next branch: all onEvents branches are now orphaned
        with current use, this will only remove the lonely final node (exit event)
      */
      getAllNodeBranchIds(nodeToRemove)
        .filter(dest => dest.type !== 'TIMER-NEXT')
        .forEach(dest => {
          removeBranchNodesExcept(dest, nothingShared)
        })
      break
    }
    case 'TIMER-EVENT': {
      if (nodeToRemove.type !== 'TIMER' || nodeToRemove.id !== branchToKeep.timerNodeId) {
        throw new Error('Invalid keep branchId TIMER-EVENT ' + branchToKeep.timerNodeId)
      }
      const foundId = nodeToRemove.onEvents.get(branchToKeep.triggerIndex)?.nextNodeId
      if (!foundId) {
        throw new Error(
          'Invalid index ' +
            branchToKeep.triggerIndex +
            ' for keep branchId TIMER-EVENT ' +
            branchToKeep.timerNodeId
        )
      }
      newNextNodeId = foundId
      // we remove a timer and keep one of its onEvents branches: all other onEvents branches and the next branch are now orphaned
      getAllNodeBranchIds(nodeToRemove)
        .filter(
          dest => dest.type !== 'TIMER-EVENT' || dest.triggerIndex !== branchToKeep.triggerIndex
        )
        .forEach(dest => {
          removeBranchNodesExcept(dest, nothingShared)
        })
      break
    }
    case 'YESNO': {
      if (nodeToRemove.type !== 'YESNO' || nodeToRemove.id !== branchToKeep.yesNoNodeId) {
        throw new Error('Invalid keep branchId YESNO ' + branchToKeep.yesNoNodeId)
      }
      // we remove a yesno: trash the other branch, up to the common descendant of yes & no
      const rejoinNode = findCommonNode({ nodesMap: nodes, node: nodeToRemove })
      removeBranchNodesExcept(
        {
          type: 'YESNO',
          yesNoNodeId: nodeToRemove.id,
          branch: branchToKeep.branch === 'yes' ? 'no' : 'yes',
        },
        getSharedFromRejoin(rejoinNode)
      )
      // since we removed the yesNo node and one branch, the nextNodeId shall be the id of the branch we keep
      newNextNodeId = branchToKeep.branch === 'yes' ? nodeToRemove.yesNodeId : nodeToRemove.noNodeId
      break
    }
    case 'RANDOM': {
      if (nodeToRemove.type !== 'RANDOM' || nodeToRemove.id !== branchToKeep.randomNodeId) {
        throw new Error('Invalid keep branchId RANDOM ' + branchToKeep.randomNodeId)
      }
      // we remove a random: trash every other split, up to the common descendant
      const shared = getSharedFromRejoin(findCommonNode({ nodesMap: nodes, node: nodeToRemove }))
      getAllNodeBranchIds(nodeToRemove)
        .filter(b => b.type === 'RANDOM' && b.splitIndex !== branchToKeep.splitIndex)
        .forEach(branchId => {
          removeBranchNodesExcept(branchId, shared)
        })
      newNextNodeId = nodeToRemove.splits.get(branchToKeep.splitIndex)?.nextNodeId ?? 'MISSING LINK'
      break
    }
  }
  // update all branchIds in nodes owning the node to be deleted
  ownerBranchIds.forEach(branchId => {
    switch (branchId.type) {
      case 'ROOT': {
        rootId = newNextNodeId
        break
      }
      case 'MESSAGE': {
        const node = getNodeById({
          nodes,
          nodeId: branchId.stepMessageNodeId,
        })
        if (!node || node.type !== 'MESSAGE') throw new Error('MESSAGE not found')
        nodes = nodes.set(node.id, node.set('nextNodeId', newNextNodeId))
        break
      }
      case 'TIMER-NEXT': {
        const node = getNodeById({
          nodes,
          nodeId: branchId.timerNodeId,
        })
        if (!node || node.type !== 'TIMER') throw new Error('TIMER not found')
        nodes = nodes.set(node.id, node.set('nextNodeId', newNextNodeId))
        break
      }
      case 'TIMER-EVENT': {
        const node = getNodeById({
          nodes,
          nodeId: branchId.timerNodeId,
        })
        if (!node || node.type !== 'TIMER') throw new Error('TIMER not found')
        nodes = nodes.set(
          node.id,
          node.setIn(['onEvents', branchId.triggerIndex, 'nextNodeId'], newNextNodeId)
        )
        break
      }
      case 'RANDOM': {
        const node = getNodeById({
          nodes,
          nodeId: branchId.randomNodeId,
        })
        if (!node || node.type !== 'RANDOM') throw new Error('RANDOM not found')
        nodes = nodes.set(
          node.id,
          node.setIn(['splits', branchId.splitIndex, 'nextNodeId'], newNextNodeId)
        )
        break
      }
      case 'YESNO': {
        const node = getNodeById({
          nodes,
          nodeId: branchId.yesNoNodeId,
        })
        if (!node || node.type !== 'YESNO') throw new Error('YESNO not found')
        nodes = nodes.set(
          node.id,
          node.set(branchId.branch === 'yes' ? 'yesNodeId' : 'noNodeId', newNextNodeId)
        )
        break
      }
    }
  })
  const allNodesToRemove = [...nodeIdFromRemovedBranches, nodeToRemove.id]
  return {
    nodes: nodes.filter(node => !allNodesToRemove.includes(node.id)),
    rootId,
  }
}

/*
  Removes, pass after pass, every node (other than the root) that no node links to, until none is
  left: removing a node can orphan the nodes it was pointing to. Returns `nodesMap` itself when
  nothing had to be removed.

  Each pass collects all link targets once (the same links getParentsBranchId inspects) instead of
  searching the parents of every node, which was quadratic in the number of nodes.
*/
export const removeOrphanNodes = ({
  nodesMap,
  rootId,
}: {
  nodesMap: Map<string, JourneyNodeRecord>
  rootId: string
}): Map<string, JourneyNodeRecord> => {
  const referencedIds: Record<string, true> = Object.create(null)
  const markReferenced = (targetId: string | null | undefined): void => {
    if (typeof targetId === 'string') referencedIds[targetId] = true
  }
  nodesMap.forEach(node => {
    switch (node.type) {
      case 'FINAL':
        break
      case 'MESSAGE':
        markReferenced(node.nextNodeId)
        break
      case 'TIMER':
        markReferenced(node.nextNodeId)
        node.onEvents.forEach(onEvent => {
          markReferenced(onEvent.nextNodeId)
        })
        break
      case 'RANDOM':
        node.splits.forEach(split => {
          markReferenced(split.nextNodeId)
        })
        break
      case 'YESNO':
        markReferenced(node.yesNodeId)
        markReferenced(node.noNodeId)
        break
    }
  })

  // the root has no parent by design, so it is never orphaned
  const isOrphan = (node: JourneyNodeRecord): boolean =>
    node.id !== rootId && !referencedIds[node.id]

  // redo a pass until we find nothing
  return nodesMap.some(isOrphan)
    ? removeOrphanNodes({ nodesMap: nodesMap.filter(node => !isOrphan(node)), rootId })
    : nodesMap
}
