import { getUseCases } from 'app/api/Usecase';
import {
  SET_ACTIVE_USECASE_COUNT,
  SET_CONFUSION_MODEL,
  SET_CONFUSION_USE_CASE,
  SET_DRAWER_STATE,
  SET_DRAWER_STATUS,
  SET_DRAWER_USECASE,
  SET_MODE,
  SET_UNIT,
  SET_MISSCLASSIFICATION_IMAGES_ROW_IDS,
  DELETE_MISSCLASSIFICATION_IMAGES_ROW_IDS,
  SET_DEFECT_DISTRIBUTION_INDIVIDUAL_LOADING,
  SET_DEFECT_DISTRIBUTION_DATA,
  DEFECT_DISTRIBUTION_CONSTANTS,
  RESET_DEFECT_DISTRIBUTION_DATA
} from './constants';
import orderBy from 'lodash/orderBy';

import { sortDefectsWithUnkownDefect } from 'app/utils/helpers';

// Display label shared by every name-like field of the synthetic "total" row.
const TOTAL_OUT_OF_DISTRIBUTION_LABEL = 'Total out of distribution labels';

/**
 * Synthetic defect used as the header of the cumulative out-of-distribution
 * row (built in getTotalOutOfDistribution). Its id (-2) matches the
 * `matrixMeta['-2']` entry produced by generateMatrixMetaFromMatrix.
 */
const DEFAULT_TOTAL_DEFECT = {
  defect_name: TOTAL_OUT_OF_DISTRIBUTION_LABEL,
  name: TOTAL_OUT_OF_DISTRIBUTION_LABEL,
  organization_defect_code: '',
  formatted_name: TOTAL_OUT_OF_DISTRIBUTION_LABEL,
  is_trained_defect: false,
  id: -2,
  isCumulativeRow: true
};

/**
 * Creates the SET_DRAWER_STATE action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setDrawerState(payload) {
  const action = { type: SET_DRAWER_STATE, payload };
  return action;
}

/**
 * Creates the SET_DRAWER_STATUS action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setDrawerStatus(payload) {
  const action = { type: SET_DRAWER_STATUS, payload };
  return action;
}

/**
 * Creates the SET_MODE action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setMode(payload) {
  const action = { type: SET_MODE, payload };
  return action;
}

/**
 * Creates the SET_UNIT action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setUnit(payload) {
  const action = { type: SET_UNIT, payload };
  return action;
}

/**
 * Creates the SET_CONFUSION_USE_CASE action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setConfusionUsecase(payload) {
  const action = { type: SET_CONFUSION_USE_CASE, payload };
  return action;
}

/**
 * Creates the SET_CONFUSION_MODEL action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setConfusionModel(payload) {
  const action = { type: SET_CONFUSION_MODEL, payload };
  return action;
}

/**
 * Creates the SET_DRAWER_USECASE action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setDrawerUsecase(payload) {
  const action = { type: SET_DRAWER_USECASE, payload };
  return action;
}

/**
 * Thunk that requests the use cases whose model status is 'deployed_in_prod'
 * for a subscription and stores the response's `count` via
 * SET_ACTIVE_USECASE_COUNT.
 *
 * @param {*} subscriptionId - forwarded to getUseCases as `subscription_id`.
 * @returns {Function} thunk `(dispatch) => Promise<void>`. The request promise
 *   is returned (it used to be dropped) so callers can await completion and
 *   handle request failures instead of leaving an unhandled rejection.
 */
export function setActiveUsecaseCount(subscriptionId) {
  return dispatch =>
    getUseCases({
      baseParams: {
        // Only `count` is read from the response below
        limit: 1,
        model_status__in: 'deployed_in_prod',
        subscription_id: subscriptionId
      }
    }).then(response => {
      // NOTE: a count of 0 (or a missing count) is stored as null
      const payload = response.count || null;
      dispatch({ type: SET_ACTIVE_USECASE_COUNT, payload });
    });
}

/**
 * Creates the SET_MISSCLASSIFICATION_IMAGES_ROW_IDS action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setMisclassificationImagesRowIds(payload) {
  const action = { type: SET_MISSCLASSIFICATION_IMAGES_ROW_IDS, payload };
  return action;
}

/**
 * Creates the DELETE_MISSCLASSIFICATION_IMAGES_ROW_IDS action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function deleteMisclassificationImagesRowIds(payload) {
  const action = { type: DELETE_MISSCLASSIFICATION_IMAGES_ROW_IDS, payload };
  return action;
}

/**
 * Creates the SET_DEFECT_DISTRIBUTION_INDIVIDUAL_LOADING action.
 * @param {*} payload - forwarded untouched as `action.payload`.
 * @returns {{type: *, payload: *}} the Redux action.
 */
export function setDefectDistributionIndividualLoading(payload) {
  const action = { type: SET_DEFECT_DISTRIBUTION_INDIVIDUAL_LOADING, payload };
  return action;
}

/**
 * Creates the RESET_DEFECT_DISTRIBUTION_DATA action; it carries no payload.
 * @returns {{type: *}} the Redux action.
 */
export function resetDefectDistributionData() {
  const action = { type: RESET_DEFECT_DISTRIBUTION_DATA };
  return action;
}

/**
 * True when a (ground truth, prediction) pair is a misclassification: both
 * ids are known (neither is -1) and they differ.
 * @param {*} gtDefect - ground-truth defect id.
 * @param {*} aiDefect - predicted defect id.
 * @returns {boolean}
 */
const isMisclassificationPair = (gtDefect, aiDefect) =>
  gtDefect !== -1 && aiDefect !== -1 && gtDefect !== aiDefect;

/**
 * Sums `count` over the rows that are misclassification pairs
 * (see isMisclassificationPair).
 * @param {Array<Object>} data - rows with `gt_defect`, `ai_defect`, `count`.
 * @returns {number} total misclassified count (0 for an empty array).
 */
const getTotalMisclassificationCount = data =>
  data
    .filter(row => isMisclassificationPair(row.gt_defect, row.ai_defect))
    .reduce((total, row) => total + row.count, 0);

/**
 * Expands the sparse confusion rows into a dense square matrix indexed as
 * `matrix[aiIndex][gtIndex]`, where indices follow getDefectList order.
 * Cells with no matching row hold a zeroed placeholder.
 *
 * @param {Array<Object>} confusionMatrix - rows with `gt_defect` and `ai_defect`.
 * @returns {Array<Array<Object>>} dense matrix of rows / placeholders.
 */
const get2dMatrix = confusionMatrix => {
  const defects = getDefectList(confusionMatrix);

  // defect id -> matrix index, built in one pass (the previous reduce re-spread
  // the accumulator on every step, which is O(n^2))
  const indexById = Object.fromEntries(defects.map((id, index) => [id, index]));

  // A fresh placeholder per cell so no two cells alias the same object
  const result = defects.map(() =>
    Array.from({ length: defects.length }, () => ({
      count: 0,
      automated_count: 0,
      audited_fileset_count: 0
    }))
  );

  confusionMatrix.forEach(row => {
    result[indexById[row.ai_defect]][indexById[row.gt_defect]] = row;
  });
  return result;
};

/**
 * Lists the distinct defect ids referenced by the confusion rows, in
 * first-seen order (each row contributes its gt id, then its ai id).
 * @param {Array<Object>} confusionMatrix - rows with `gt_defect`, `ai_defect`.
 * @returns {Array<*>} unique defect ids.
 */
const getDefectList = confusionMatrix => {
  const uniqueIds = new Set(
    confusionMatrix.flatMap(row => [row['gt_defect'], row['ai_defect']])
  );
  return Array.from(uniqueIds);
};

/**
 * Builds the DEFECT_BASED_DISTRIBUTION action: one summary row per defect id
 * found in the confusion matrix, excluding id -1.
 *
 * get2dMatrix lays cells out as `matrix[aiIndex][gtIndex]`, so for the defect
 * at `index`:
 *   - row `index` holds every cell the AI predicted as this defect;
 *   - column `index` holds every cell whose ground truth is this defect.
 *
 * @param {{confusion_matrix: Array<Object>, defects: Object}} payload -
 *   `defects` is looked up by defect id to resolve the display name.
 * @returns {{type: *, payload: {type: *, data: Array<Object>}}} the action.
 */
export function setDefectBasedDistribution(payload) {
  const { confusion_matrix, defects } = payload;
  const matrix = get2dMatrix(confusion_matrix);
  const defectsList = getDefectList(confusion_matrix);

  const result = matrix.map((predictedAsDefect, index) => {
    let correct = 0;
    let extra = 0;
    // Precision denominator: predictions of this defect whose ground truth is known
    let predictedWithKnownGt = 0;

    predictedAsDefect.forEach((cell, gtIndex) => {
      if (gtIndex === index) {
        correct += cell.count;
      } else if (cell.ai_defect !== -1 && cell.gt_defect !== -1) {
        extra += cell.count;
      }
      if (cell.gt_defect !== -1) {
        predictedWithKnownGt += cell.count;
      }
    });

    let missed = 0;
    let total = 0;
    let auto = 0;
    let manual = 0;
    let audited = 0;
    // Recall denominator: ground-truth occurrences that received an AI prediction
    let actualWithAiPrediction = 0;

    matrix.forEach((row, aiIndex) => {
      const cell = row[index];
      if (aiIndex !== index) {
        // Counts every non-matching cell, including ai_defect === -1
        missed += cell.count;
      }
      if (cell.ai_defect !== -1) {
        actualWithAiPrediction += cell.count;
        auto += cell.count;
      } else {
        manual += cell.count;
      }
      total += cell.count;
      audited += cell.audited_fileset_count;
    });

    return {
      name: defects[defectsList[index]]?.name || '',
      defect_id: defectsList[index],
      total,
      auto,
      manual,
      audited,
      correct,
      missed,
      extra,
      recall_percentage: actualWithAiPrediction
        ? Math.round((correct * 100) / actualWithAiPrediction)
        : null,
      precision_percentage: predictedWithKnownGt
        ? Math.round((correct * 100) / predictedWithKnownGt)
        : null
    };
  });

  return {
    type: SET_DEFECT_DISTRIBUTION_DATA,
    payload: {
      type: DEFECT_DISTRIBUTION_CONSTANTS.DEFECT_BASED_DISTRIBUTION,
      data: result.filter(item => item.defect_id !== -1)
    }
  };
}

/**
 * Assigns a `rank` bucket to each row from its cumulative
 * `misclassification_contribution_percent` and strips that field from every
 * returned row.
 *
 * getRankFromRawData passes rows sorted by count (descending) with cumulative
 * percentages, so the checks below compare each row against the last value
 * pushed into each bucket:
 *   rank 0 - percent <= 50, or percent > 50 while the last rank-0 value is
 *            <= 50 or rank 0 is still empty (the row that crosses 50% joins rank 0);
 *   rank 1 - 50 < percent <= 80, or percent > 80 while the last rank-1 value
 *            is strictly between 50 and 80 (the row that crosses 80% joins rank 1);
 *   rank 2 - any other percent > 80.
 * Rows whose percentage is falsy (null, 0, NaN) are returned without `rank`.
 *
 * NOTE(review): the 80% carry-over requires a non-empty rank-1 bucket, unlike
 * the 50% carry-over. E.g. percentages [40, 60, 90]: 60 joins rank 0 as the
 * crossing row, and 90 goes to rank 2 rather than rank 1. Confirm intended.
 *
 * @param {Array<Object>} data - rows carrying `misclassification_contribution_percent`.
 * @returns {Array<Object>} new row objects without the percentage field.
 */
const getConfusionRank = data => {
  // Values already placed in each bucket; only '0-50' and '50-80' are read back.
  const contributionMap = {
    '0-50': [],
    '50-80': [],
    '80+': []
  };
  return data.map(({ misclassification_contribution_percent, ...rest }) => {
    // Not a misclassification pair (or 0 / NaN contribution): no rank
    if (!misclassification_contribution_percent) return rest;

    // Rank 0: up to 50%, plus the first row past 50% (an empty array's last
    // element is undefined, hence the explicit length check)
    if (
      misclassification_contribution_percent <= 50 ||
      (misclassification_contribution_percent > 50 &&
        (contributionMap['0-50'][contributionMap['0-50'].length - 1] <= 50 ||
          contributionMap['0-50'].length === 0))
    ) {
      contributionMap['0-50'].push(misclassification_contribution_percent);
      return { ...rest, rank: 0 };
    }

    // Rank 1: (50%, 80%], plus the first row past 80% when rank 1 is non-empty
    if (
      (misclassification_contribution_percent > 50 &&
        misclassification_contribution_percent <= 80) ||
      (misclassification_contribution_percent > 80 &&
        contributionMap['50-80'][contributionMap['50-80'].length - 1] < 80 &&
        contributionMap['50-80'][contributionMap['50-80'].length - 1] > 50)
    ) {
      contributionMap['50-80'].push(misclassification_contribution_percent);
      return { ...rest, rank: 1 };
    }

    // Rank 2: everything else above 80%
    if (misclassification_contribution_percent > 80) {
      contributionMap['80+'].push(misclassification_contribution_percent);
      return { ...rest, rank: 2 };
    }

    // Fallback for values matching none of the numeric checks above
    return rest;
  });
};

/**
 * Sorts confusion rows by `count` (descending), attaches each
 * misclassification row's cumulative contribution percentage (null for
 * non-misclassification rows), then ranks them with getConfusionRank.
 * @param {Array<Object>} rawData - rows with `gt_defect`, `ai_defect`, `count`.
 * @returns {Array<Object>} ranked rows.
 */
const getRankFromRawData = rawData => {
  const byCountDesc = orderBy(rawData, ['count'], ['desc']);
  const total = getTotalMisclassificationCount(byCountDesc);

  let runningTotal = 0;
  const withContribution = [];
  for (const row of byCountDesc) {
    let percent = null;
    if (isMisclassificationPair(row.gt_defect, row.ai_defect)) {
      runningTotal += row.count;
      percent = (runningTotal * 100) / total;
    }
    withContribution.push({
      ...row,
      misclassification_contribution_percent: percent
    });
  }

  return getConfusionRank(withContribution);
};

/**
 * Indexes ranked confusion rows as `matrix[gtId].ai_defects[aiId] = row`.
 * (Drops the previously collected but never used list of AI label ids.)
 * @param {Array<Object>} rawData - rows with `gt_defect`, `ai_defect`, `count`.
 * @returns {Object} nested lookup keyed by ground-truth id, then predicted id.
 */
const generateMatrixFromRawData = rawData => {
  const matrix = {};
  getRankFromRawData(rawData).forEach(item => {
    const gtId = item.gt_defect;
    if (!matrix[gtId]) {
      matrix[gtId] = { ai_defects: {} };
    }
    matrix[gtId].ai_defects[item.ai_defect] = item;
  });
  return matrix;
};

/**
 * Computes per-defect summary metrics for the confusion matrix.
 *
 * For each defect id the result holds:
 *   - model_count: sum of `count` over rows predicted (ai_defect) as this id;
 *   - gt_count:    sum of `count` over rows whose gt_defect is this id;
 *   - recall:      % of this id's GT count (excluding ai_defect === -1) predicted correctly;
 *   - precision:   % of this id's predictions (excluding gt_defect === -1) that were correct.
 * recall/precision are 'N/A' for id -1 or when the denominator is 0, and keep
 * the default 0 for ids that never appear as ground truth. Key '-2' aggregates
 * gt_count over untrained (out-of-distribution) defects.
 *
 * Ids referenced by rawData/matrix but missing from `defects` get a zeroed
 * entry; previously they caused a TypeError.
 *
 * @param {Object} matrix - gt id -> { ai_defects: { ai id -> row } }.
 * @param {Array<Object>} defects - records with `id` and `is_trained_defect`.
 * @param {Array<Object>} rawData - rows with `gt_defect`, `ai_defect`, `count`.
 * @returns {Object} defect id -> { recall, precision, model_count, gt_count }.
 */
const generateMatrixMetaFromMatrix = (matrix, defects, rawData) => {
  const matrixMeta = {};
  const DEFAULT_META_VALUE = {
    recall: 0,
    precision: 0,
    model_count: 0,
    gt_count: 0
  };

  // Returns the entry for `id`, creating a zeroed one when it does not exist yet
  const metaFor = id => {
    if (!matrixMeta[id]) {
      matrixMeta[id] = { ...DEFAULT_META_VALUE };
    }
    return matrixMeta[id];
  };

  const outOfDistributionDefects = [];
  defects.forEach(defect => {
    metaFor(defect.id);
    if (!defect.is_trained_defect) outOfDistributionDefects.push(defect);
  });

  // Cumulative out-of-distribution row
  matrixMeta['-2'] = { ...DEFAULT_META_VALUE, recall: 'N/A', precision: 'N/A' };

  // ai id -> count of its predictions whose ground truth is unknown (gt_defect === -1)
  const unknownGtCountByAi = {};
  rawData.forEach(({ gt_defect, ai_defect, count }) => {
    if (gt_defect === -1) {
      unknownGtCountByAi[ai_defect] = (unknownGtCountByAi[ai_defect] || 0) + count;
    }
    metaFor(ai_defect).model_count += count;
  });

  Object.keys(matrix).forEach(gtId => {
    const meta = metaFor(gtId);
    const predictions = matrix[gtId].ai_defects;
    const isUnknownGt = Number(gtId) === -1;

    // GT occurrences that received no AI prediction (ai_defect === -1)
    let unpredictedCount = 0;
    meta.gt_count = Object.values(predictions).reduce((sum, cell) => {
      if (cell.ai_defect === -1) {
        unpredictedCount += cell.count;
      }
      return sum + cell.count;
    }, 0);

    const correctCount = predictions[gtId]?.count || 0;

    const actualGTCount = meta.gt_count - unpredictedCount;
    meta.recall =
      actualGTCount && !isUnknownGt ? (correctCount * 100) / actualGTCount : 'N/A';

    const actualModelCount =
      (meta.model_count || 0) - (unknownGtCountByAi[gtId] || 0);
    meta.precision =
      actualModelCount && !isUnknownGt
        ? (correctCount * 100) / actualModelCount
        : 'N/A';
  });

  matrixMeta['-2'].gt_count = outOfDistributionDefects.reduce(
    (sum, defect) => sum + matrixMeta[defect.id].gt_count,
    0
  );

  return matrixMeta;
};

/**
 * Builds the confusion lookup and its per-defect metrics from raw rows.
 * @param {Array<Object>} rawData - confusion rows.
 * @param {Array<Object>} defects - defect records.
 * @returns {{matrix: Object, matrixMeta: Object}}
 */
const getConfusionMatrix = (rawData, defects) => {
  const confusion = { matrix: generateMatrixFromRawData(rawData) };
  confusion.matrixMeta = generateMatrixMetaFromMatrix(
    confusion.matrix,
    defects,
    rawData
  );
  return confusion;
};

/**
 * Sorts matrix rows, and the `ai_defects` cells inside each row, by
 * organization defect code then name via sortDefectsWithUnkownDefect.
 * @param {string} order - 'asc' sorts ascending; any other value is passed as non-ascending.
 * @param {Array<Object>} defects - rows shaped `{ defect, ai_defects }`.
 * @returns {Array<Object>} new row objects with sorted `ai_defects`.
 */
export const sortNormalizedDefect = (order, defects) => {
  const isAscending = order === 'asc';
  // Builds a fresh key list for each call, as before
  const sortRows = rows =>
    sortDefectsWithUnkownDefect(
      rows,
      ['defect.organization_defect_code', 'defect.name'],
      isAscending
    );

  const sortedRows = sortRows(defects);
  return sortedRows.map(row => ({ ...row, ai_defects: sortRows(row.ai_defects) }));
};

/**
 * Collapses out-of-distribution rows into a single DEFAULT_TOTAL_DEFECT row
 * whose cells sum `count` column-by-column. Column defects are taken from the
 * first row; rows missing a column contribute 0.
 * @param {Array<Object>} data - rows shaped `{ defect, ai_defects: [{ defect, count }] }`.
 * @returns {Array<Object>} `[]` for empty input, else a one-element array.
 */
const getTotalOutOfDistribution = data => {
  if (!data.length) return [];

  const [firstRow] = data;
  const cumulativeAiDefects = firstRow.ai_defects.map(({ defect }, column) => ({
    defect,
    count: data.reduce(
      (sum, row) => sum + (row['ai_defects'][column]?.count || 0),
      0
    )
  }));

  return [{ defect: DEFAULT_TOTAL_DEFECT, ai_defects: cumulativeAiDefects }];
};

/**
 * Splits matrix rows into trained (in-distribution) and untrained
 * (out-of-distribution) groups and sorts each. The out-of-distribution group
 * is prefixed with a cumulative total row, or replaced by it alone when
 * `isCollapsed` is truthy and the group is non-empty.
 * @param {string} order - sort order forwarded to sortNormalizedDefect.
 * @param {Array<Object>} defects - rows shaped `{ defect, ai_defects }`.
 * @param {boolean} isCollapsed - show only the cumulative out-of-distribution row.
 * @returns {{inDistribution: Array<Object>, outOfDistribution: Array<Object>}}
 */
export const sortDefectsForMatrix = (order, defects, isCollapsed) => {
  const trained = [];
  const untrained = [];
  defects.forEach(item =>
    (item.defect.is_trained_defect ? trained : untrained).push(item)
  );

  const inDistribution = sortNormalizedDefect(order, trained);
  const sortedOutOfDistribution = sortNormalizedDefect(order, untrained);
  const totalRow = getTotalOutOfDistribution(sortedOutOfDistribution);

  const showOnlyTotal = isCollapsed && sortedOutOfDistribution.length;
  return {
    inDistribution,
    outOfDistribution: showOnlyTotal
      ? totalRow
      : [...totalRow, ...sortedOutOfDistribution]
  };
};

/**
 * Builds the CONFUSION_MATRICS action. The matrix has one row per
 * ground-truth defect and one cell per trained (in-distribution) predicted
 * defect.
 *
 * - Trained GT defects always get a row; it is zero-filled when they have no data.
 * - Untrained GT defects get a row only if some in-distribution cell has count > 0.
 * - Predictions of untrained defects are not represented as cells.
 *
 * @param {{confusion_matrix: Array<Object>, defects: Object}} payload -
 *   `defects` is an object of defect records, iterated with Object.values.
 * @returns {{type: *, payload: {type: *, data: {matrix: Array, matrixMeta: Object}}}}
 */
export function setConfusionMatrics(payload) {
  const { confusion_matrix, defects } = payload;

  // Also covers a missing confusion_matrix, which previously threw on `.length`
  if (!confusion_matrix?.length) {
    return {
      type: SET_DEFECT_DISTRIBUTION_DATA,
      payload: {
        type: DEFECT_DISTRIBUTION_CONSTANTS.CONFUSION_MATRICS,
        data: { matrix: [], matrixMeta: {} }
      }
    };
  }

  const allDefects = Object.values(defects);
  const inDistributionDefects = allDefects.filter(item => item.is_trained_defect);
  const { matrix: formattedMatrix, matrixMeta } = getConfusionMatrix(
    confusion_matrix,
    allDefects
  );

  // Cell used wherever a GT/predicted pair has no recorded data
  const emptyCell = predicted => ({ defect: predicted, count: 0, file_set_ids: [] });

  const result = [];
  allDefects.forEach(gtDefect => {
    const observedRow = formattedMatrix[gtDefect.id];

    if (!observedRow) {
      // No data for this GT defect: keep trained rows (zero-filled), drop the rest
      if (gtDefect.is_trained_defect) {
        result.push({
          defect: gtDefect,
          gt_count: 0,
          ai_defects: inDistributionDefects.map(predicted => emptyCell(predicted))
        });
      }
      return;
    }

    const predictions = observedRow.ai_defects;
    const cells = inDistributionDefects.map(predicted =>
      predictions[predicted.id]
        ? { defect: predicted, ...predictions[predicted.id] }
        : emptyCell(predicted)
    );

    if (gtDefect.is_trained_defect || cells.some(cell => cell.count > 0)) {
      result.push({ ...observedRow, ai_defects: cells, defect: gtDefect });
    }
  });

  return {
    type: SET_DEFECT_DISTRIBUTION_DATA,
    payload: {
      type: DEFECT_DISTRIBUTION_CONSTANTS.CONFUSION_MATRICS,
      data: { matrix: result, matrixMeta }
    }
  };
}

/**
 * Builds the MISCLASSIFICATION_PAIR action: every (gt, ai) pair that is a
 * misclassification (see isMisclassificationPair), with its share of all
 * misclassifications, sorted by count descending.
 *
 * @param {{confusion_matrix: Array<Object>, defects: Object}} payload -
 *   `defects` is looked up by defect id; unknown ids resolve to `{}`.
 * @returns {{type: *, payload: {type: *, data: Array<Object>}}} the action.
 *   `misclassification_percent` is 0 (instead of NaN) when the total
 *   misclassification count is 0.
 */
export function setMisclassificationPairs(payload) {
  const { confusion_matrix, defects } = payload;

  const totalMisclassificationCount =
    getTotalMisclassificationCount(confusion_matrix);

  // filter + map instead of re-spreading an accumulator per row (O(n^2))
  const result = confusion_matrix
    .filter(row => isMisclassificationPair(row.gt_defect, row.ai_defect))
    .map(row => ({
      ai_defect_id: row['ai_defect'],
      gt_defect_id: row['gt_defect'],
      ai_defect: defects[row['ai_defect']] || {},
      gt_defect: defects[row['gt_defect']] || {},
      misclassification_count: row['count'],
      // Guard against 0/0 when every misclassified pair has a zero count
      misclassification_percent: totalMisclassificationCount
        ? (100 * row['count']) / totalMisclassificationCount
        : 0
    }))
    .sort((a, b) => b.misclassification_count - a.misclassification_count);

  return {
    type: SET_DEFECT_DISTRIBUTION_DATA,
    payload: {
      type: DEFECT_DISTRIBUTION_CONSTANTS.MISCLASSIFICATION_PAIR,
      data: result
    }
  };
}
