from collections import deque
from typing import List, Union, Optional
from itertools import groupby

# Type aliases used throughout the solver.
# A variable is a positive int in 1..n. A literal is a non-zero int whose sign
# carries its polarity: `v` means variable v is True, `-v` means it is False.
Var = int
Lit = int
# A clause is a disjunction of literals; a CNF formula is a conjunction of clauses.
Clause = List[Lit]
CNF = List[Clause]
# A partial assignment indexed by variable (index 0 is unused): True, False, or
# None for "not yet assigned".
Assignment = List[Optional[bool]]


# Sample formula over 2 variables: (-x1 OR x2) AND (-x1 OR -x2).
n = 2
cnf = [[-1, 2], [-1, -2]]


def preprocess(cnf: CNF) -> CNF:
    """Remove duplicate literals and duplicate clauses from a CNF formula.

    Each clause is normalized to the sorted list of its distinct literals, so two
    clauses with the same literals in a different order (or with repeats) compare
    equal and are merged. The input is not mutated.

    Args:
        cnf: list of clauses, each a list of integer literals.

    Returns:
        A new list of distinct, normalized clauses, sorted lexicographically.
    """
    # Sort each clause instead of relying on set iteration order: that order
    # depends on hash collisions and insertion order, so equal clauses could
    # otherwise come out in different orders and escape de-duplication.
    normalized = sorted(sorted(set(clause)) for clause in cnf)
    # groupby over the sorted list collapses each run of equal clauses into one.
    return [clause for clause, _ in groupby(normalized)]


class PennSAT:
    """A DPLL-based SAT solver with a static most-frequent decision heuristic.

    Variables are numbered ``1..n``. A literal is a non-zero int: ``v`` asserts
    that variable ``v`` is True and ``-v`` asserts that it is False.
    """

    def __init__(self, n: int, cnf: CNF, heuristic: bool = True, unit_propagation: bool = True):
        """Set up solver state for a formula over variables ``1..n``.

        Args:
            n: the number of variables.
            cnf: the formula as a list of clauses (duplicates are removed).
            heuristic: if True, branch on variables in descending order of how
                often they occur in the formula; otherwise branch in order ``1..n``.
            unit_propagation: if False, the solver only checks for falsified
                clauses instead of propagating unit clauses.
        """
        # The number of variables
        self.n = n
        # The CNF as a list of clauses, with duplicate literals/clauses removed
        self.cnf: CNF = preprocess(cnf)
        # A stack of partial truth assignments; each maps a variable (index 1..n,
        # index 0 unused) to True/False/None. The top is the current assignment.
        self.assignment_stack: List[Assignment] = [[None] * (n + 1)]
        # The static order in which `pick_variable` considers variables
        self.var_ordering: List[Var] = (
            self.compute_frequency_ordering() if heuristic else list(range(1, n + 1))
        )
        # Whether `unit_propagate` propagates unit clauses or only detects conflicts
        self.unit_propagation = unit_propagation
        # A stack of decision literals, one per assignment level above the base
        self.decision_stack: List[Var] = []
        # Maps each literal to the clauses containing it. Indexed by the literal
        # itself: positive literals use slots 1..n and negative literals wrap
        # around to slots n+1..2n through Python's negative indexing.
        self.clauses_containing: List[List[Clause]] = [[] for _ in range(2 * n + 1)]
        # A queue of literals that have just become False; the clauses containing
        # them still need to be checked for unit propagation / conflicts.
        self.propagation_queue: deque = deque()

        for clause in self.cnf:
            for literal in clause:
                self.clauses_containing[literal].append(clause)

    def value(self, literal: Lit) -> Optional[bool]:
        """Return the value of ``literal`` (True/False/None) under the current assignment."""
        var_value = self.assignment_stack[-1][abs(literal)]
        if var_value is None:
            return None
        # A negative literal is True exactly when its variable is False
        return var_value if literal > 0 else not var_value

    def assume(self, literal: Lit) -> bool:
        """Set ``literal`` to True in the current assignment.

        Returns True if the literal was newly assigned or was already True, and
        False if it was already False (a conflict). When the literal is newly
        assigned, its negation (which just became False) is queued for propagation.
        """
        current = self.value(literal)
        if current is not None:
            return current
        self.assignment_stack[-1][abs(literal)] = literal > 0
        # `-literal` is now False, so clauses containing it may have become unit
        self.propagation_queue.append(-literal)
        return True

    def decide(self, lit: Lit) -> None:
        """Open a new decision level and assume ``lit`` True in it.

        The current assignment is copied onto the stack so the decision can later
        be undone by `backtrack_and_assume_negation`. ``lit`` is expected to be
        unassigned (it comes from `pick_variable`).
        """
        self.assignment_stack.append(self.assignment_stack[-1][:])
        self.decision_stack.append(lit)
        self.assume(lit)

    def compute_frequency_ordering(self) -> List[Var]:
        """Return the variables ``1..n`` sorted by descending occurrence count in
        the stored CNF; ties are broken by ascending variable number."""
        frequencies = [0] * (self.n + 1)
        for clause in self.cnf:
            for literal in clause:
                frequencies[abs(literal)] += 1
        # Sort only the real variables 1..n (index 0 is a placeholder and must
        # never be returned). Python's sort is stable even with reverse=True, so
        # equally frequent variables keep ascending order.
        return sorted(range(1, self.n + 1), key=lambda var: frequencies[var], reverse=True)

    def pick_variable(self) -> Optional[Var]:
        """Return the first unassigned variable in the stored variable ordering,
        or `None` if all variables have been assigned."""
        for var in self.var_ordering:
            if self.value(var) is None:
                return var
        return None

    def backtrack_and_assume_negation(self) -> None:
        """Undo the most recent decision and assume its negation one level down.

        The assignment level opened by the decision is discarded, the decision is
        popped, pending propagation work is dropped, and the negated decision is
        assumed in the restored (previous) assignment.
        """
        self.assignment_stack.pop()
        negated = -self.decision_stack.pop()
        # Every queued literal was produced after the undone decision (the queue
        # is always empty when a decision is made), so none of them is valid now.
        self.propagation_queue.clear()
        self.assume(negated)

    def unit_propagate_clause(self, clause: Clause) -> bool:
        """
        Accepts as input a clause, and performs unit propagation if the
        clause is unit (unsatisfied and containing 1 unassigned literal).
        Returns `False` if a conflict is detected (unsatisfied and 0
        unassigned literals); otherwise `True`.
        """
        unassigned = []
        for literal in clause:
            literal_value = self.value(literal)
            if literal_value is True:
                return True  # clause already satisfied: nothing to do
            if literal_value is None:
                unassigned.append(literal)
        if not unassigned:
            return False  # every literal is False: conflict
        if len(unassigned) == 1:
            self.assume(unassigned[0])
        return True

    def unit_propagate(self) -> bool:
        """
        Calls `unit_propagate_clause` on the clauses containing literals popped from
        `propagation_queue` until the queue is empty. Returns `False` on a conflict,
        otherwise `True`. If `self.unit_propagation` is `False`, only checks whether
        any clause is fully falsified, without doing unit propagation.
        """
        if not self.unit_propagation:
            # The queue is never consumed in this mode; drop it so it does not
            # grow without bound.
            self.propagation_queue.clear()
            return not any(
                all(self.value(literal) is False for literal in clause)  # clause is unsatisfied
                for clause in self.cnf
            )
        while self.propagation_queue:
            falsified = self.propagation_queue.popleft()
            for clause in self.clauses_containing[falsified]:
                if not self.unit_propagate_clause(clause):
                    return False
        return True

    def solve(self) -> Union[str, List[Lit]]:
        """Return a satisfying assignment to the stored CNF, or `'UNSAT'` if none exists.

        The assignment is a list of literals for variables ``1..n`` in order.
        """
        # An empty clause can never be satisfied. It appears in no
        # `clauses_containing` list, so unit propagation would never notice it.
        if any(not clause for clause in self.cnf):
            return "UNSAT"
        # Unit clauses force their literal before any decision is made
        for clause in self.cnf:
            if len(clause) == 1 and not self.assume(clause[0]):
                return "UNSAT"
        if not self.unit_propagate():
            return "UNSAT"

        while True:
            var = self.pick_variable()
            if var is None:
                # All variables are assigned; any unassigned one would default to True
                model = self.assignment_stack[-1]
                return [v if model[v] is not False else -v for v in range(1, self.n + 1)]

            # Branch on the variable being True first
            self.decide(var)

            # On conflict, flip the most recent decision; running out of decisions
            # means the formula is unsatisfiable.
            while not self.unit_propagate():
                if not self.decision_stack:
                    return "UNSAT"
                self.backtrack_and_assume_negation()


# Example usage on the sample formula defined above:
# solver = PennSAT(n, cnf)
# print(solver.solve())
