package eve.data

import androidx.compose.runtime.Stable


/**
 * A damage pattern is the amount of damage of each [DamageType].
 *
 * The values are stored in [damageByTypeOrdinal], indexed by [DamageType.ordinal]. None of the operations below
 * modify that array; they always return a new pattern (or this one).
 */
class DamagePattern(val damageByTypeOrdinal: DoubleArray) {


    /**
     * Returns the damage of the given type.
     */
    @Stable
    operator fun get(damageType: DamageType): Double = damageByTypeOrdinal[damageType.ordinal]


    /**
     * Returns the type-by-type sum of this and the given damage pattern.
     */
    @Stable
    operator fun plus(damagePattern: DamagePattern): DamagePattern =
        DamagePattern { this[it] + damagePattern[it] }


    /**
     * Returns a damage pattern with the damage of each type multiplied by the given factor.
     */
    @Stable
    operator fun times(factor: Int): DamagePattern = DamagePattern { this[it] * factor }


    /**
     * Returns a damage pattern with the damage of each type multiplied by the given (possibly fractional) factor.
     */
    @Stable
    operator fun times(factor: Double): DamagePattern = DamagePattern { this[it] * factor }


    /**
     * Normalizes the damage pattern: adjusts damages such that their sum is `1.0`, while ratios are preserved.
     *
     * Returns this pattern itself if it is already normalized. A pattern whose damages sum to `0.0` cannot be
     * normalized (every ratio would be `0.0 / 0.0`, i.e. `NaN`), so it is also returned unchanged.
     */
    @Stable
    fun normalize(): DamagePattern {
        val sum = damageByTypeOrdinal.sum()
        if ((sum == 1.0) || (sum == 0.0))
            return this

        return DamagePattern { this[it] / sum }
    }


    /**
     * Calls [action] with each damage type and its corresponding damage value, in [DamageType.entries] order.
     */
    @Stable
    inline fun forEach(action: (DamageType, Double) -> Unit) {
        // forEachIndexed on a DoubleArray is inline and avoids boxing each value into an IndexedValue.
        damageByTypeOrdinal.forEachIndexed { index, damage ->
            action(DamageType.entries[index], damage)
        }
    }


    /**
     * The sum of damage values.
     */
    @Stable
    val totalDamage: Double
        get() = damageByTypeOrdinal.sum()


    /**
     * The damage type of which this damage pattern has the most.
     *
     * If several types tie for the most damage, the first of them in [DamageType.entries] order is returned.
     */
    @Stable
    val primaryDamageType: DamageType
        get() = DamageType.entries
            .maxBy { this[it] }


    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is DamagePattern) return false

        return damageByTypeOrdinal.contentEquals(other.damageByTypeOrdinal)
    }


    override fun hashCode(): Int {
        return damageByTypeOrdinal.contentHashCode()
    }


    /**
     * Returns a content-based description, consistent with [equals] and [hashCode].
     */
    override fun toString(): String = "DamagePattern(${damageByTypeOrdinal.contentToString()})"


    companion object {

        /**
         * The uniform damage pattern, where there is an equal amount of every damage type, summing to `1.0`.
         */
        @JvmStatic
        val Uniform: DamagePattern =
            DamagePattern(DoubleArray(DamageType.entries.size) { 1.0 / DamageType.entries.size })

        /**
         * The no damage pattern, where there is no damage at all.
         */
        @JvmStatic
        val None: DamagePattern = DamagePattern(DoubleArray(DamageType.entries.size))

    }

}


/**
 * Returns a [DamagePattern] whose damage for each [DamageType] is given by [damageOfType].
 *
 * [damageOfType] is invoked once per damage type, in [DamageType.entries] order.
 */
inline fun DamagePattern(crossinline damageOfType: (DamageType) -> Double): DamagePattern {
    val damageTypes = DamageType.entries
    // Size the array from the enum rather than a hard-coded count, so it always matches the
    // ordinal-based indexing used by DamagePattern.get.
    return DamagePattern(DoubleArray(damageTypes.size) { damageOfType(damageTypes[it]) })
}


/**
 * Adds up, type by type, the damage patterns that [itemDamagePattern] produces for the items of this iterable.
 *
 * Items mapped to `null` contribute nothing. If there are no items, or all of them map to `null`, the result is an
 * all-zero pattern.
 */
@Stable
fun <T> Iterable<T>.sumOfDamagePatterns(itemDamagePattern: (T) -> DamagePattern?): DamagePattern {
    val totals = DoubleArray(DamageType.entries.size)
    forEach { item ->
        itemDamagePattern(item)?.damageByTypeOrdinal?.forEachIndexed { ordinal, damage ->
            totals[ordinal] += damage
        }
    }
    return DamagePattern(totals)
}


/**
 * Returns the type-by-type sum of all damage patterns in this iterable; an empty iterable yields an all-zero pattern.
 */
@Stable
fun Iterable<DamagePattern>.sum(): DamagePattern = sumOfDamagePatterns { pattern -> pattern }


/**
 * Substitutes [DamagePattern.Uniform] for a missing or damage-less pattern.
 *
 * Returns [DamagePattern.Uniform] when this pattern is `null` or every one of its damage values equals `0.0`;
 * otherwise returns this pattern unchanged.
 */
fun DamagePattern?.uniformIfNullOrNone(): DamagePattern = when {
    this == null -> DamagePattern.Uniform
    this.damageByTypeOrdinal.any { it != 0.0 } -> this
    else -> DamagePattern.Uniform
}
