import argparse

from rdkit import Chem, rdBase

# Turn off RDKit's 'rdApp.error' and 'rdApp.warning' log channels
# (presumably so that unparsable predicted SMILES do not flood the console).
rdBase.DisableLog('rdApp.error')
rdBase.DisableLog('rdApp.warning')


class Evaluator(object):
    """Score predicted SMILES strings against reference SMILES strings.

    Calling the instance with ``(hypo, ref)`` compares one prediction with
    one reference and appends a result dict (keys ``"label"``,
    ``"predicted"``, ``"canonical"``) to ``self.results``. The label is one
    of:

    * ``"MATCH"``   -- the prediction parses and its canonical SMILES is
      equal to the reference string.
    * ``"MISTAKE"`` -- the prediction parses but its canonical SMILES
      differs from the reference string.
    * ``"MISS"``    -- the prediction is the literal string ``"MISS"``.
    * ``"ERROR"``   -- ``Chem.MolFromSmiles`` returned ``None`` for the
      prediction.

    The reference string is compared verbatim; it is not canonicalised
    here, so callers must supply references already in the form produced
    by ``Chem.MolToSmiles`` with the same ``isomeric`` setting.
    """

    def __init__(self, isomeric=True):
        """Create an evaluator.

        Args:
            isomeric: passed (as a bool) to ``Chem.MolToSmiles`` as
                ``isomericSmiles`` when canonicalising predictions.
        """
        self.isomeric = isomeric

        # One dict per evaluated (hypo, ref) pair, in call order.
        self.results = []

        self._match = "MATCH"
        self._mistake = "MISTAKE"
        self._miss = "MISS"
        self._error = "ERROR"
        self.result_labels = [
            self._match,
            self._mistake,
            self._miss,
            self._error,
        ]

    def __call__(self, hypo, ref):
        """Evaluate one prediction ``hypo`` against ``ref`` and record it."""
        label, canonical = self._determine_identity(ref, hypo)

        result = EvalResult()
        result.set_result(label, hypo, canonical)
        self.results.append(result.get_result())

    def _determine_identity(self, ref, predicted):
        """Return ``(label, canonical_smiles)`` for a single prediction.

        ``canonical_smiles`` is ``None`` for MISS and ERROR outcomes.
        """
        # The literal "MISS" marks a prediction the system did not produce.
        if predicted == self._miss:
            return self._miss, None

        pred_mol = Chem.MolFromSmiles(predicted)
        if pred_mol is None:
            # RDKit could not build a molecule from the predicted string.
            return self._error, None

        return self._check_smiles_match(ref, pred_mol)

    def _check_smiles_match(self, ref_smiles, pred_mol):
        """Canonicalise ``pred_mol`` and compare it with ``ref_smiles``.

        Returns:
            ``(label, canonical_smiles)`` where the label is MATCH when the
            canonical string equals ``ref_smiles`` and MISTAKE otherwise.
            A ``None`` molecule yields ``(MISTAKE, None)``.
        """
        if pred_mol is None:
            return self._mistake, None

        canonical_pred = Chem.MolToSmiles(
            pred_mol, isomericSmiles=bool(self.isomeric))
        if canonical_pred is not None and ref_smiles == canonical_pred:
            return self._match, canonical_pred
        return self._mistake, canonical_pred

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 for a zero denominator."""
        return numerator / denominator if denominator else 0.0

    def show(self, args):
        """Print recall, precision, f-measure and validity to stdout.

        * recall    = matches / all results
        * precision = matches / (all results - misses - errors)
        * f-measure = harmonic mean of the rounded precision and recall
        * validity  = (matches + mistakes) / all results

        Every value is rounded to three decimals. A ratio whose denominator
        is zero (no results, nothing but misses/errors, or no matches at
        all) is reported as 0.0 instead of raising ``ZeroDivisionError``.

        Args:
            args: unused; accepted so existing callers keep working.
        """
        labels = [r["label"] for r in self.results]
        n_entire = len(labels)
        n_matches = labels.count(self._match)
        n_mistakes = labels.count(self._mistake)
        n_misses = labels.count(self._miss)
        n_errors = labels.count(self._error)

        # Predictions that produced a parsable molecule.
        n_answered = n_entire - n_misses - n_errors

        recall = round(self._ratio(n_matches, n_entire), 3)
        precision = round(self._ratio(n_matches, n_answered), 3)
        # Computed from the already-rounded precision and recall.
        f_measure = round(
            2 * self._ratio(precision * recall, precision + recall), 3)
        validity = round(self._ratio(n_matches + n_mistakes, n_entire), 3)

        print(f"recall = {recall}")
        print(f"precision = {precision}")
        print(f"f-measure = {f_measure}")
        print(f"validity = {validity}")


class EvalResult(object):
    """Holder for the outcome of evaluating a single prediction."""

    def __init__(self):
        # Initialised up front so get_result() is safe to call before
        # set_result(); it returns None until a result has been stored.
        self._result = None

    def get_result(self):
        """Return the stored result dict, or ``None`` if none was set."""
        return self._result

    def set_result(self, label, predicted, canonical):
        """Store one evaluation outcome.

        Args:
            label: outcome label for the prediction.
            predicted: the prediction as it was given to the evaluator.
            canonical: canonical form of the prediction, or ``None``.
        """
        self._result = {
            "label": label,
            "predicted": predicted,
            "canonical": canonical,
        }


def main():
    """Command-line entry point: score a prediction file against a reference.

    Reads the ``--pred`` and ``--ref`` files line by line, evaluates each
    prediction against the reference on the same line, and prints the
    aggregate metrics. All spaces are removed from the prediction file
    before it is split into lines; the reference file is used as-is.
    """
    import sys

    parser = argparse.ArgumentParser()
    # Both files are mandatory; without required=True a missing option
    # would only surface later as a TypeError from open(None).
    parser.add_argument(
        "--pred", "-p", required=True,
        help="a file consisting of SMILES strings that a system predicted for the evaluation dataset, one per line")
    parser.add_argument(
        "--ref", "-r", required=True,
        help="a file consisting of correct SMILES strings for the evaluation dataset, one per line")

    args = parser.parse_args()

    with open(args.pred, "r") as f:
        hypos = f.read().replace(" ", "").splitlines()
    with open(args.ref, "r") as f:
        refs = f.read().splitlines()

    # zip() below stops at the shorter list, so surface a mismatch instead
    # of silently scoring only part of the data.
    if len(hypos) != len(refs):
        print(
            f"warning: {args.pred} has {len(hypos)} lines but {args.ref} "
            f"has {len(refs)}; only the first {min(len(hypos), len(refs))} "
            "pairs are evaluated",
            file=sys.stderr,
        )

    evaluator = Evaluator()

    for hypo, ref in zip(hypos, refs):
        evaluator(hypo, ref)

    evaluator.show(args)


# Run the command-line evaluation only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
