//////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2010, 2026 Contributors to the Eclipse Foundation
//
// See the NOTICE file(s) distributed with this work for additional
// information regarding copyright ownership.
//
// This program and the accompanying materials are made available
// under the terms of the MIT License which is available at
// https://opensource.org/licenses/MIT
//
// SPDX-License-Identifier: MIT
//////////////////////////////////////////////////////////////////////////////

package org.eclipse.escet.cif.bdd.conversion.bitvectors;

import java.math.BigInteger;

import org.eclipse.escet.cif.bdd.spec.CifBddDomain;
import org.eclipse.escet.common.java.Assert;

import com.github.javabdd.BDD;
import com.github.javabdd.BDDFactory;

/**
 * Unsigned BDD bit vector.
 *
 * <p>
 * Bit {@code 0} is the least significant bit. A vector of {@code n} bits represents values from {@code 0} up to and
 * including {@code 2^n - 1}.
 * </p>
 */
public class UnsignedBddBitVector extends BddBitVector<UnsignedBddBitVector, UnsignedBddBitVectorAndCarry> {
    /** The minimum length (in number of bits) of any {@link UnsignedBddBitVector}. */
    public static final int MINIMUM_LENGTH = 1;

    /**
     * Constructor for the {@link UnsignedBddBitVector} class.
     *
     * @param factory The BDD factory to use.
     * @param length The number of bits of the bit vector.
     * @throws IllegalArgumentException If the length is less than one.
     */
    private UnsignedBddBitVector(BDDFactory factory, int length) {
        super(factory, length);
    }

    @Override
    protected int getMinimumLength() {
        return MINIMUM_LENGTH;
    }

    /**
     * Returns the minimum length (in number of bits) needed to represent the given non-negative integer value as an
     * {@link UnsignedBddBitVector}.
     *
     * @param value The non-negative integer value.
     * @return The minimum length.
     */
    public static int getMinimumLength(int value) {
        Assert.check(value >= 0);

        // The number of significant bits is the bit width minus the number of leading zero bits. The value zero has
        // no significant bits, but a bit vector still needs at least one bit to represent it.
        int significantBits = Integer.SIZE - Integer.numberOfLeadingZeros(value);
        return Math.max(MINIMUM_LENGTH, significantBits);
    }

    @Override
    protected UnsignedBddBitVector createEmpty(int length) {
        return new UnsignedBddBitVector(this.factory, length);
    }

    /**
     * Creates an {@link UnsignedBddBitVector}. Initializes the bits of the bit vector to 'false'.
     *
     * @param factory The BDD factory to use.
     * @param length The number of bits of the bit vector.
     * @return The created bit vector.
     * @throws IllegalArgumentException If the length is less than one.
     */
    public static UnsignedBddBitVector create(BDDFactory factory, int length) {
        return create(factory, length, false);
    }

    /**
     * Creates an {@link UnsignedBddBitVector}. Initializes each bit of the bit vector to the given boolean value.
     *
     * @param factory The BDD factory to use.
     * @param length The number of bits of the bit vector.
     * @param value The value to use for each bit.
     * @return The created bit vector.
     * @throws IllegalArgumentException If the length is less than one.
     */
    public static UnsignedBddBitVector create(BDDFactory factory, int length, boolean value) {
        // Create.
        UnsignedBddBitVector vector = new UnsignedBddBitVector(factory, length);

        // Initialize. Each bit gets its own BDD instance, so that each can be freed independently.
        for (int i = 0; i < vector.bits.length; i++) {
            vector.bits[i] = value ? factory.one() : factory.zero();
        }

        // Return.
        return vector;
    }

    /**
     * Creates an {@link UnsignedBddBitVector} from an integer value. Initializes the bits of the bit vector to the
     * given integer value. Uses an as small as possible bit vector to represent the integer value.
     *
     * @param factory The BDD factory to use.
     * @param value The integer value to represent using a bit vector.
     * @return The created bit vector.
     * @throws IllegalArgumentException If the value is negative.
     */
    public static UnsignedBddBitVector createFromInt(BDDFactory factory, int value) {
        // Precondition check.
        if (value < 0) {
            throw new IllegalArgumentException("Value is negative.");
        }

        // Create.
        int length = getMinimumLength(value);
        return createFromInt(factory, length, value);
    }

    /**
     * Creates an {@link UnsignedBddBitVector} from an integer value. Initializes the bits of the bit vector to the
     * given integer value, creating a bit vector of the given length. If the requested length is larger than the needed
     * number of bits, the remaining/highest bits are set to 'false'.
     *
     * @param factory The BDD factory to use.
     * @param length The number of bits of the bit vector.
     * @param value The integer value to represent using a bit vector.
     * @return The created bit vector.
     * @throws IllegalArgumentException If the value is negative.
     * @throws IllegalArgumentException If the length is insufficient to store the given value.
     */
    public static UnsignedBddBitVector createFromInt(BDDFactory factory, int length, int value) {
        // Precondition checks.
        if (value < 0) {
            throw new IllegalArgumentException("Value is negative.");
        }
        if (length < getMinimumLength(value)) {
            throw new IllegalArgumentException("Length is insufficient.");
        }

        // Create.
        UnsignedBddBitVector vector = new UnsignedBddBitVector(factory, length);

        // Initialize, from least significant bit (index 0) to most significant bit.
        for (int i = 0; i < vector.bits.length; i++) {
            vector.bits[i] = ((value & 0x1) != 0) ? factory.one() : factory.zero();
            value >>= 1;
        }
        Assert.areEqual(value, 0);

        // Return.
        return vector;
    }

    /**
     * Creates an {@link UnsignedBddBitVector} from a {@link CifBddDomain}. Initializes the bits of the bit vector to
     * the variables of the given domain. The length of the bit vector is the number of variables in the given domain.
     *
     * @param domain The domain to use.
     * @return The created bit vector.
     */
    public static UnsignedBddBitVector createFromDomain(CifBddDomain domain) {
        // Create bit vector.
        int varCnt = domain.getVarCount();
        UnsignedBddBitVector vector = new UnsignedBddBitVector(domain.getFactory(), varCnt);

        // Initialize. Bit 'i' of the vector is the 'i'-th variable of the domain.
        int[] vars = domain.getVarIndices();
        for (int i = 0; i < vars.length; i++) {
            vector.bits[i] = vector.factory.ithVar(vars[i]);
        }

        // Return.
        return vector;
    }

    @Override
    public BigInteger getLower() {
        return BigInteger.ZERO;
    }

    @Override
    public int getLowerInt() {
        return 0;
    }

    @Override
    public BigInteger getUpper() {
        // All bits 'true': 2^length - 1.
        return BigInteger.TWO.pow(bits.length).subtract(BigInteger.ONE);
    }

    @Override
    public int getUpperInt() {
        return getUpper().intValueExact();
    }

    @Override
    public Integer getInt() {
        // Check for enough room to represent the value.
        if (bits.length > 31) {
            throw new IllegalStateException("More than 31 bits in vector.");
        }

        // Return value.
        Long value = getLong();
        return (value == null) ? null : (int)(long)value;
    }

    @Override
    public Long getLong() {
        // Check for enough room to represent the value.
        if (bits.length > 63) {
            throw new IllegalStateException("More than 63 bits in vector.");
        }

        // Get value, processing bits from most significant to least significant.
        long value = 0;
        for (int bitIndex = bits.length - 1; bitIndex >= 0; bitIndex--) {
            if (bits[bitIndex].isOne()) {
                // Shift already-considered higher bits. Add current '1' bit.
                value = (value << 1) | 1;
            } else if (bits[bitIndex].isZero()) {
                // Shift already-considered higher bits. Nothing to add for '0' bit.
                value = (value << 1);
            } else {
                // Not a constant value.
                return null;
            }
        }
        return value;
    }

    @Override
    public void setInt(int value) {
        // Precondition checks.
        if (value < 0) {
            throw new IllegalArgumentException("Value is negative.");
        }
        if (bits.length < getMinimumLength(value)) {
            throw new IllegalArgumentException("Length is insufficient.");
        }

        // Set value, freeing each replaced bit.
        for (int i = 0; i < bits.length; i++) {
            bits[i].free();
            bits[i] = ((value & 0x1) != 0) ? factory.one() : factory.zero();
            value >>= 1;
        }
        Assert.areEqual(value, 0);
    }

    /**
     * {@inheritDoc}
     *
     * <p>
     * For unsigned bit vectors, the additional (most significant) bits are set to 'false'.
     * </p>
     *
     * @throws IllegalArgumentException If the new length is less than one.
     */
    @Override
    public void resize(int length) {
        // Optimization.
        if (length == bits.length) {
            return;
        }

        // Precondition check.
        if (length < MINIMUM_LENGTH) {
            throw new IllegalArgumentException("Length is less than one.");
        }

        // Allocate new bits.
        BDD[] newBits = new BDD[length];

        // Copy the common bits.
        int numberOfCommonBits = Math.min(bits.length, length);
        System.arraycopy(bits, 0, newBits, 0, numberOfCommonBits);

        // If new length is larger, set additional bits to 'false'.
        for (int i = numberOfCommonBits; i < length; i++) {
            newBits[i] = factory.zero();
        }

        // If new length is smaller, free dropped bits.
        for (int i = numberOfCommonBits; i < bits.length; i++) {
            bits[i].free();
        }

        // Replace the bits.
        bits = newBits;
    }

    @Override
    public UnsignedBddBitVector shrink() {
        // Get the minimum-required length of this vector. For an unsigned bit vector, all most-significant 'false' bits
        // can be dropped without changing the vector's value(s). We do take the minimum-required length of the bit
        // vector representation into account as well.
        int length = bits.length;
        while (length > MINIMUM_LENGTH && bits[length - 1].isZero()) {
            length--;
        }

        // Resize to the minimum-required length.
        resize(length);

        // Return this bit vector, for chaining.
        return this;
    }

    /**
     * {@inheritDoc}
     *
     * <p>
     * This operation is not supported for {@link UnsignedBddBitVector}.
     * </p>
     *
     * @throws UnsupportedOperationException Always thrown.
     */
    @Override
    public UnsignedBddBitVectorAndCarry negate() {
        throw new UnsupportedOperationException();
    }

    @Override
    public UnsignedBddBitVectorAndCarry abs() {
        // Unsigned values are already non-negative, so the absolute value is a copy, and never overflows.
        return new UnsignedBddBitVectorAndCarry(copy(), factory.zero());
    }

    @Override
    public UnsignedBddBitVectorAndCarry add(UnsignedBddBitVector other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result, using a ripple-carry adder from least to most significant bit.
        UnsignedBddBitVector rslt = new UnsignedBddBitVector(factory, bits.length);
        BDD carry = factory.zero();
        for (int i = 0; i < bits.length; i++) {
            // rslt[i] = this[i] ^ other[i] ^ carry
            rslt.bits[i] = this.bits[i].xor(other.bits[i]).xorWith(carry.id());

            // carry = (this[i] & other[i]) | (carry & (this[i] | other[i]))
            // Note that 'andWith' consumes the previous carry.
            carry = this.bits[i].and(other.bits[i]).orWith(carry.andWith(this.bits[i].or(other.bits[i])));
        }
        return new UnsignedBddBitVectorAndCarry(rslt, carry);
    }

    @Override
    public UnsignedBddBitVector sign() {
        // Compute 'this = 0'.
        BDD isZero = factory.one();
        for (BDD bit: bits) {
            isZero = isZero.andWith(bit.not());
        }

        // Compute result vector 'if this = 0 then 0 else 1 end'.
        UnsignedBddBitVector zero = createFromInt(factory, 1, 0);
        UnsignedBddBitVector one = createFromInt(factory, 1, 1);
        UnsignedBddBitVector result = ifThenElse(isZero, zero, one);

        // Cleanup.
        isZero.free();
        zero.free();
        one.free();

        // Return the result.
        return result;
    }

    @Override
    public UnsignedBddBitVectorAndCarry subtract(UnsignedBddBitVector other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result, using a ripple-borrow subtractor from least to most significant bit. The 'carry' is the
        // borrow.
        UnsignedBddBitVector rslt = new UnsignedBddBitVector(factory, bits.length);
        BDD carry = factory.zero();
        for (int i = 0; i < bits.length; i++) {
            // rslt[i] = this[i] ^ other[i] ^ carry
            rslt.bits[i] = this.bits[i].xor(other.bits[i]).xorWith(carry.id());

            // carry = (this[i] & other[i] & carry) | (!this[i] & (other[i] | carry))
            // Note that 'andWith' consumes the previous carry.
            BDD tmp1 = other.bits[i].or(carry);
            BDD tmp2 = this.bits[i].apply(tmp1, BDDFactory.less);
            tmp1.free();
            carry = this.bits[i].and(other.bits[i]).andWith(carry).orWith(tmp2);
        }
        return new UnsignedBddBitVectorAndCarry(rslt, carry);
    }

    @Override
    public UnsignedBddBitVectorAndCarry multiply(UnsignedBddBitVector other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Zero-extend the multiplicand to double length, such that the full product fits without loss.
        int length = this.bits.length;
        int doubleLength = length * 2;
        UnsignedBddBitVector left = this.copy();
        left.resize(doubleLength);

        // Perform shift-and-add multiplication. For each bit 'i' of the multiplier ('other'), conditionally add the
        // multiplicand shifted left by 'i' positions. Only the 'length' bits of the multiplier need to be considered,
        // as any zero-extended bits beyond them would be 'false', and thus never contribute to the product.
        UnsignedBddBitVector result = create(factory, doubleLength, false);
        for (int i = 0; i < length; i++) {
            UnsignedBddBitVectorAndCarry added = result.add(left);

            // result = if other[i] then result + left else result end
            for (int j = 0; j < doubleLength; j++) {
                BDD oldBit = result.bits[j];
                result.bits[j] = other.bits[i].ite(added.vector.bits[j], oldBit);
                oldBit.free();
            }

            added.vector.free();
            added.carry.free();

            // Shift the multiplicand for the next multiplier bit.
            UnsignedBddBitVector oldLeft = left;
            left = left.shiftLeft(1, factory.zero());
            oldLeft.free();
        }

        left.free();

        // Compute overflow. The upper 'length' bits must all be 'false', for there to be no overflow. If any of the
        // upper 'length' bits is different from 'false', it indicates overflow.
        BDD overflow = factory.zero();
        for (int i = length; i < doubleLength; i++) {
            overflow = overflow.orWith(result.bits[i].id());
        }

        // Truncate to the right length. This frees the upper bits.
        result.resize(length);

        // Return the result.
        return new UnsignedBddBitVectorAndCarry(result, overflow);
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException If the divisor is not positive.
     * @throws IllegalArgumentException If the divisor doesn't fit within this bit vector.
     */
    @Override
    public UnsignedBddBitVector div(int divisor) {
        return divmod(divisor, true);
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException If the divisor is not positive.
     * @throws IllegalArgumentException If the divisor doesn't fit within this bit vector.
     * @throws IllegalStateException If the highest bit of this vector is not 'false'.
     */
    @Override
    public UnsignedBddBitVector mod(int divisor) {
        return divmod(divisor, false);
    }

    /**
     * Computes the quotient ('div' result) or remainder ('mod' result) of dividing this vector (the dividend) by the
     * given value (the divisor). This operation returns a new bit vector. The bit vector on which the operation is
     * performed is neither modified nor {@link #free freed}.
     *
     * @param divisorValue The value by which to divide this bit vector.
     * @param isDiv Whether to compute and return the quotient/'div' ({@code true}) or remainder/'mod' ({@code false}).
     * @return The quotient ('div' result) or remainder ('mod' result).
     * @throws IllegalArgumentException If the divisor is not positive.
     * @throws IllegalArgumentException If the divisor doesn't fit within this bit vector.
     * @throws IllegalStateException If the remainder/'mod' is computed, and the highest bit of this bit vector is not
     *     'false'.
     */
    private UnsignedBddBitVector divmod(int divisorValue, boolean isDiv) {
        // Precondition checks.
        if (divisorValue <= 0) {
            throw new IllegalArgumentException("Divisor is not positive.");
        }
        if (bits.length < getMinimumLength(divisorValue)) {
            throw new IllegalArgumentException("Divisor doesn't fit.");
        }
        if (!isDiv && !bits[bits.length - 1].isZero()) {
            throw new IllegalStateException(
                    "Computing the remainder/'mod', and the highest bit of the dividend is not 'false'.");
        }

        // Create divisor vector.
        UnsignedBddBitVector divisor = createFromInt(factory, bits.length, divisorValue);

        // Create result vectors. The quotient starts as the dividend shifted left by one bit. The dividend's highest
        // bit is shifted into the (initially zero) remainder.
        UnsignedBddBitVector quotient = shiftLeft(1, factory.zero());
        UnsignedBddBitVector remainderZero = create(factory, bits.length);
        UnsignedBddBitVector remainder = remainderZero.shiftLeft(1, bits[bits.length - 1]);
        remainderZero.free();

        // Compute result.
        divModRecursive(divisor, quotient, remainder, bits.length);
        divisor.free();

        // Return requested result.
        if (isDiv) {
            remainder.free();
            return quotient;
        } else {
            quotient.free();
            UnsignedBddBitVector shiftedRemainder = remainder.shiftRight(1, factory.zero());
            remainder.free();
            return shiftedRemainder;
        }
    }

    /**
     * Computes the quotient ('div' result) and remainder ('mod' result) of dividing a bit vector (the dividend) by the
     * given other bit vector (the divisor).
     *
     * @param divisor The divisor bit vector. Is not modified.
     * @param quotient The quotient/'div' bit vector, as computed so far. Is modified in-place.
     * @param remainder The remainder/'mod' bit vector, as computed so far. Is modified in-place.
     * @param step The number of steps to perform.
     */
    private void divModRecursive(UnsignedBddBitVector divisor, UnsignedBddBitVector quotient,
            UnsignedBddBitVector remainder, int step)
    {
        int divLen = divisor.bits.length;

        // Determine whether the divisor fits in the current partial remainder, and shift that into the quotient.
        BDD isSmaller = divisor.lessOrEqual(remainder);
        UnsignedBddBitVector newQuotient = quotient.shiftLeft(1, isSmaller);

        // Compute the value to subtract: the divisor if it fits, zero otherwise. Using the raw constructor avoids
        // allocating (and then leaking) 'false' BDDs that would immediately be overwritten.
        UnsignedBddBitVector sub = new UnsignedBddBitVector(factory, divLen);
        for (int i = 0; i < divLen; i++) {
            sub.bits[i] = isSmaller.and(divisor.bits[i]);
        }

        // Subtract, and shift the next dividend bit (the highest bit of the old quotient) into the remainder.
        UnsignedBddBitVectorAndCarry tmp = remainder.subtract(sub);
        UnsignedBddBitVector newRemainder = tmp.vector.shiftLeft(1, quotient.bits[divLen - 1]);

        if (step > 1) {
            divModRecursive(divisor, newQuotient, newRemainder, step - 1);
        }

        tmp.vector.free();
        tmp.carry.free();
        sub.free();
        isSmaller.free();

        quotient.replaceBy(newQuotient);
        remainder.replaceBy(newRemainder);
    }

    @Override
    public BDD lessThan(UnsignedBddBitVector other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result, from least to most significant bit. A more significant bit overrides the outcome of the less
        // significant bits, unless both vectors have the same value for that bit. Starts from 'false', as equal vectors
        // are not less than each other.
        BDD rslt = factory.zero();
        for (int i = 0; i < bits.length; i++) {
            // rslt = (!this[i] & other[i]) | biimp(this[i], other[i]) & rslt
            BDD lt = this.bits[i].apply(other.bits[i], BDDFactory.less);
            BDD eq = this.bits[i].biimp(other.bits[i]);
            rslt = lt.orWith(eq.andWith(rslt));
        }
        return rslt;
    }

    @Override
    public BDD lessOrEqual(UnsignedBddBitVector other) {
        // Precondition check.
        if (this.bits.length != other.bits.length) {
            throw new IllegalArgumentException("Different lengths.");
        }

        // Compute result, from least to most significant bit. A more significant bit overrides the outcome of the less
        // significant bits, unless both vectors have the same value for that bit. Starts from 'true', as equal vectors
        // are less than or equal to each other.
        BDD rslt = factory.one();
        for (int i = 0; i < bits.length; i++) {
            // rslt = (!this[i] & other[i]) | biimp(this[i], other[i]) & rslt
            BDD lt = this.bits[i].apply(other.bits[i], BDDFactory.less);
            BDD eq = this.bits[i].biimp(other.bits[i]);
            rslt = lt.orWith(eq.andWith(rslt));
        }
        return rslt;
    }

    @Override
    public UnsignedBddBitVector min(UnsignedBddBitVector other) {
        // Compute 'this <= other'.
        BDD cmp = this.lessOrEqual(other);

        // Compute result: 'if this <= other: this else other end'.
        UnsignedBddBitVector result = ifThenElse(cmp, this, other);

        // Cleanup.
        cmp.free();

        // Return the result.
        return result;
    }

    @Override
    public UnsignedBddBitVector max(UnsignedBddBitVector other) {
        // Compute 'this >= other'.
        BDD cmp = this.greaterOrEqual(other);

        // Compute result: 'if this >= other: this else other end'.
        UnsignedBddBitVector result = ifThenElse(cmp, this, other);

        // Cleanup.
        cmp.free();

        // Return the result.
        return result;
    }
}
