/**
 * Stores a collection of items on disk.
 */

package theorycrafter.utils

import java.io.*
import java.nio.channels.Channels


/**
 * The version of the storage format written by this code.
 */
private const val STORAGE_FORMAT_VERSION: Int = 2


/**
 * The file offset at which the header (metadata) starts.
 */
private const val METADATA_POSITION: Long = 0


/**
 * The file offset of the entry count; it follows the two 4-byte format version numbers that open the header.
 */
private const val ENTRY_COUNT_POSITION = METADATA_POSITION + 2 * Int.SIZE_BYTES


/**
 * The file offset at which the user metadata starts; it follows the 4-byte entry count and the 4-byte user
 * metadata format version.
 */
private const val USER_METADATA_POSITION = ENTRY_COUNT_POSITION + 2 * Int.SIZE_BYTES


/**
 * Writes the header and the user metadata of a brand-new, empty collection into [file].
 *
 * The header records the current storage format version, the given [itemFormatVersion], an entry count of zero and
 * the format version reported by [userMetadataSerializer]. The [userMetadata] is serialized right after the header.
 */
private fun <M> initializeFile(
    file: RandomAccessFile,
    itemFormatVersion: Int,
    userMetadata: M,
    userMetadataSerializer: StoredCollection.Serializer<M>
) {
    val emptyHeader = Metadata(
        storageFormatVersion = STORAGE_FORMAT_VERSION,
        itemFormatVersion = itemFormatVersion,
        entryCount = 0,
        userMetadataFormatVersion = userMetadataSerializer.itemFormatVersion
    )
    file.writeMetadata(emptyHeader)

    // The header must end exactly where the user metadata is expected to begin
    val headerEnd = file.filePointer
    check(headerEnd == USER_METADATA_POSITION)

    userMetadataSerializer.writeItem(file, userMetadata)
}


/**
 * The fixed-size header we keep at the beginning of the file.
 *
 * @property storageFormatVersion The version number of the storage format.
 * @property itemFormatVersion The version number of the item format.
 * @property entryCount The number of entries in the file.
 * @property userMetadataFormatVersion The version number of the user metadata format.
 */
private class Metadata(
    val storageFormatVersion: Int,
    val itemFormatVersion: Int,
    val entryCount: Int,
    val userMetadataFormatVersion: Int
)


/**
 * Writes the four header fields of [metadata] at the start of the file, as consecutive 4-byte integers.
 *
 * The file pointer is left immediately after the last field written.
 */
private fun RandomAccessFile.writeMetadata(metadata: Metadata) {
    seek(METADATA_POSITION)
    with(metadata) {
        writeInt(storageFormatVersion)
        writeInt(itemFormatVersion)

        // The entry count is also rewritten on its own, so it must sit at its well-known offset
        check(filePointer == ENTRY_COUNT_POSITION)
        writeInt(entryCount)

        writeInt(userMetadataFormatVersion)
    }
}


/**
 * Overwrites the entry count field of the header with [count], leaving the rest of the header untouched.
 */
private fun RandomAccessFile.writeEntryCount(count: Int) {
    this.seek(ENTRY_COUNT_POSITION)
    this.writeInt(count)
}


/**
 * Reads the header from the start of the file.
 *
 * The user metadata format version is only read when the stored storage format version is 2 or later; for older
 * files it is reported as `0` and the file pointer is left right after the entry count.
 */
private fun RandomAccessFile.readMetadata(): Metadata {
    seek(METADATA_POSITION)

    val storedStorageVersion = readInt()
    val storedItemVersion = readInt()
    val storedEntryCount = readInt()
    val storedUserMetadataVersion = when {
        storedStorageVersion >= 2 -> readInt()
        else -> 0
    }

    return Metadata(
        storageFormatVersion = storedStorageVersion,
        itemFormatVersion = storedItemVersion,
        entryCount = storedEntryCount,
        userMetadataFormatVersion = storedUserMetadataVersion
    )
}


/**
 * Runs migrations on the given file from the given storage format version to the current version.
 *
 * Only the migration from version 1 is implemented; for any other [fromVersion] this function does nothing.
 *
 * The version 1 migration:
 * 1. Serializes the value returned by [initialUserMetadata] into memory, to learn how many bytes it occupies.
 * 2. Moves everything stored from offset 64 onwards so that it starts right after the space the user metadata
 *    will occupy, then truncates the file at the end of the moved data.
 * 3. Writes [metadata] into the header exactly as given (its version fields are not modified here).
 * 4. Writes the serialized user metadata right after the header.
 *
 * NOTE(review): steps 2-4 appear to rely on the file's channel and the file itself sharing one position; confirm.
 *
 * @param fromVersion The storage format version the file currently has.
 * @param metadata The header to write into the file.
 * @param initialUserMetadata Provides the user metadata to store in a file that did not have any.
 * @param userMetadataSerializer Serializes the user metadata.
 */
private fun <M> RandomAccessFile.upgrade(
    fromVersion: Int,
    metadata: Metadata,
    initialUserMetadata: () -> M,
    userMetadataSerializer: StoredCollection.Serializer<M>
) {
    if (fromVersion == 1) {
        // Serialize the user metadata into memory first; its size determines where the entries must be moved to
        val serializedUserMetadata = ByteArrayOutputStream()
        userMetadataSerializer.writeItem(DataOutputStream(serializedUserMetadata), initialUserMetadata())
        val userMetadataSize = serializedUserMetadata.size()

        // Move the user data to the right position
        // Everything up to the end of the file is read into memory before any of it is written back
        seek(64L)  // This is where the entries started in version 1
        val userData = Channels.newInputStream(this.channel).buffered().readAllBytes()
        seek(USER_METADATA_POSITION + userMetadataSize)
        Channels.newOutputStream(this.channel).buffered().let {
            it.write(userData)
            it.flush()
        }
        this.setLength(this.filePointer)  // Truncate, if needed

        // Write the metadata
        // writeMetadata seeks to the start of the file and stops right after the last header field
        writeMetadata(metadata)

        // Write the user metadata
        Channels.newOutputStream(this.channel).buffered().let {
            serializedUserMetadata.writeTo(it)
            it.flush()
        }
    }
}


/**
 * One entry of the file: the stored item together with where and how large it is on disk.
 *
 * @property item The item; `null` if the entry is deleted.
 * @property position The position of the entry in the file.
 * @property length The length of the entry, in bytes.
 */
private class Entry<T>(
    val item: T?,
    val position: Long,
    val length: Int
) {


    /**
     * `true` when this entry no longer holds an item.
     */
    val isDeleted: Boolean
        get() = (item == null)


    /**
     * Creates the deleted counterpart of this entry: no item, but the same position and length.
     */
    fun deleted(): Entry<T> {
        return Entry(item = null, position = position, length = length)
    }


}


/**
 * Reads a single entry from the stream.
 *
 * An entry consists of a deleted flag, the length of the item data and the item data itself. For a deleted entry
 * the item data is skipped and the returned [Entry] has a `null` item; otherwise the item is read with [serializer].
 *
 * @throws CorruptedStoredCollection if the stored length is negative, or differs from the number of bytes actually
 * consumed while reading (or skipping) the item data.
 */
private fun <T> PositionDataInputStream.readEntry(
    serializer: StoredCollection.Serializer<T>,
    itemFormatVersion: Int
): Entry<T> {
    val entryPosition = position

    val deletedFlag = readBoolean()
    val declaredLength = readInt()
    if (declaredLength < 0)
        throw CorruptedStoredCollection("Read negative item length")

    // Remember where the item data begins, to verify afterwards how much of it was consumed
    val dataStart = position
    val item: T? = when {
        deletedFlag -> {
            skipNBytes(declaredLength.toLong())
            null
        }
        else -> serializer.readItem(this, itemFormatVersion)
    }

    val consumed = (position - dataStart).toInt()
    if (consumed != declaredLength)
        throw CorruptedStoredCollection("Entry $item has length $declaredLength but read $consumed bytes")

    return Entry(item = item, position = entryPosition, length = declaredLength)
}


/**
 * Writes [item] as a single, non-deleted entry at the stream's current position.
 *
 * The item is first serialized into [buffer] (which is reset beforehand) so that its length is known before the
 * entry header is written. Returns the [Entry] describing what was written.
 */
private fun <T> PositionDataOutputStream.writeEntry(
    item: T,
    serializer: StoredCollection.Serializer<T>,
    buffer: OutputBuffer
): Entry<T> {
    // Serialize into the scratch buffer to learn the size of the item data
    val dataLength = with(buffer) {
        reset()
        serializer.writeItem(this, item)
        size()
    }

    val entryPosition = position

    writeBoolean(false)  // Deleted flag: this entry is live
    writeInt(dataLength)  // Size of the item data
    buffer.writeTo(this)  // The item data itself

    return Entry(item = item, position = entryPosition, length = dataLength)
}


/**
 * Flags an entry as deleted by writing a `true` boolean at the current position of [file].
 */
private fun markEntryDeleted(file: RandomAccessFile) =
    file.writeBoolean(true)  // Deleted entry


/**
 * A collection of items backed by storage on disk.
 *
 * Note that the current implementation loads the entire collection into memory, and synchronously flushes each change
 * to disk.
 *
 * IMPORTANT: This class is not thread-safe.
 *
 * The collection keeps its file open until [close] is called; it implements [Closeable], so it can be used with `use`.
 * If loading the collection fails, the file is closed before the exception is propagated.
 *
 * The format of the file is:
 * - A fixed-size header with metadata:
 *     - (4 bytes) Storage format version, to allow upgrades.
 *     - (4 bytes) Item format version, to allow user data upgrades.
 *     - (4 bytes) Number of entries in the file.
 *     - (4 bytes) User metadata format version, to allow user metadata upgrades.
 *     - User metadata.
 * - A sequence of entries, where each entry is:
 *     - One byte specifying whether this entry has been deleted. `1` if deleted, `0` if not.
 *     - The size of the item data block on disk, in bytes.
 *     - The item data block. This contains either the item, or garbage if the entry has been deleted.
 *
 * When an item is added, its entry is always written at the end of the file.
 * When an item is removed, its entry is marked as deleted.
 * When the collection is loaded from the file, it will "defrag" it, overwriting the deleted entries with non-deleted
 * ones.
 */
class StoredCollection<M, T>(


    /**
     * The file where the items are to be stored.
     */
    storageFile: File,


    /**
     * The serializer for user metadata.
     *
     * Note that the length of the metadata for each format version must not change.
     * The item format is 0 when reading a stored collection from before user metadata was stored in it.
     */
    private val userMetadataSerializer: Serializer<M>,


    /**
     * Returns the initial user metadata, used when the collection is being created for the first time.
     */
    initialUserMetadata: () -> M,


    /**
     * When the format of user metadata changes, this function is called to upgrade the metadata itself.
     */
    userMetadataUpgrader: (metadata: M, version: Int, items: Sequence<T>) -> M,


    /**
     * The serializer for user items.
     */
    private val itemSerializer: Serializer<T>,


    /**
     * The number of bytes into which items will almost always fit.
     * This will be used as the initial value of an internal buffer when writing items.
     */
    typicalMaxItemSizeBytes: Int = 8192


) : Closeable {


    /**
     * The file where we store the items.
     */
    private val file = RandomAccessFile(storageFile, "rw")


    /**
     * The user metadata.
     */
    var userMetadata: M
        private set


    /**
     * The position in the file where the entries begin.
     */
    private var dataPosition: Long


    /**
     * The list of entries, including deleted ones (their value is `null`).
     */
    private val entries: MutableList<Entry<T>>


    /**
     * A buffer we use when writing entries.
     */
    private val buffer = outputBuffer(typicalMaxItemSizeBytes)


    init {
        try {
            // New file
            if (file.length() == 0L) {
                initializeFile(
                    file = file,
                    itemFormatVersion = itemSerializer.itemFormatVersion,
                    userMetadata = initialUserMetadata(),
                    userMetadataSerializer = userMetadataSerializer
                )
            }

            // Read metadata and upgrade if needed
            val metadata = file.readMetadata()
            if (metadata.storageFormatVersion < STORAGE_FORMAT_VERSION) {
                println("Upgrading stored collection at $storageFile" +
                        " from version ${metadata.storageFormatVersion} to $STORAGE_FORMAT_VERSION")

                file.upgrade(
                    fromVersion = metadata.storageFormatVersion,
                    metadata = metadata,
                    initialUserMetadata = initialUserMetadata,
                    userMetadataSerializer = userMetadataSerializer
                )
            }

            userMetadata = userMetadataSerializer.readItem(file, metadata.userMetadataFormatVersion)
            dataPosition = file.filePointer

            // A negative count can only come from a damaged header; report it as such rather than letting the
            // list allocation below fail with an unrelated exception
            if (metadata.entryCount < 0)
                throw CorruptedStoredCollection("Read negative entry count")

            // Read the entries
            val input = inputAt(dataPosition)
            entries = ArrayList(metadata.entryCount)
            for (i in 0 until metadata.entryCount)
                entries.add(input.readEntry(itemSerializer, metadata.itemFormatVersion))

            // If the item or metadata format changed, or any entries are marked as deleted, re-write everything
            val userMetadataFormatChanged =
                metadata.userMetadataFormatVersion != userMetadataSerializer.itemFormatVersion
            val itemFormatChanged = metadata.itemFormatVersion != itemSerializer.itemFormatVersion
            val entriesDeleted = entries.any(Entry<T>::isDeleted)
            if (userMetadataFormatChanged || itemFormatChanged || entriesDeleted) {
                // Update user metadata
                if (userMetadataFormatChanged) {
                    userMetadata = userMetadataUpgrader(userMetadata, metadata.userMetadataFormatVersion, items())
                    file.seek(USER_METADATA_POSITION)
                    userMetadataSerializer.writeItem(file, userMetadata)
                    dataPosition = file.filePointer
                }

                // Update entries
                val output = outputAt(dataPosition)
                var remainingEntryCount = 0
                for (index in entries.indices) {
                    val item = entries[index].item ?: continue  // Skip deleted entries
                    entries[remainingEntryCount++] = output.writeEntry(item, itemSerializer, buffer)
                }
                output.flush()

                file.setLength(file.filePointer)  // Truncate file
                entries.removeTailFrom(remainingEntryCount)  // Truncate entry list

                // Update item format version and/or entry count
                file.writeMetadata(
                    Metadata(
                        storageFormatVersion = STORAGE_FORMAT_VERSION,
                        itemFormatVersion = itemSerializer.itemFormatVersion,
                        entryCount = entries.size,
                        userMetadataFormatVersion = userMetadataSerializer.itemFormatVersion
                    )
                )

                file.fd.sync()
            }
        } catch (e: Throwable) {
            // The caller never receives this instance when construction fails, so nobody else could close the file
            try {
                file.close()
            } catch (closeError: IOException) {
                e.addSuppressed(closeError)
            }
            throw e
        }
    }


    /**
     * Updates the user metadata, both in memory and on disk.
     *
     * The metadata is serialized into memory first, so that its size can be verified before the file is touched:
     * the entries begin right after the user metadata, and writing a different number of bytes would corrupt them.
     *
     * @throws IllegalArgumentException if the serialized size of [metadata] differs from the space the user metadata
     * occupies in the file.
     */
    fun setUserMetadata(metadata: M) {
        val serialized = ByteArrayOutputStream()
        DataOutputStream(serialized).let {
            userMetadataSerializer.writeItem(it, metadata)
            it.flush()
        }

        val reservedSize = dataPosition - USER_METADATA_POSITION
        require(serialized.size().toLong() == reservedSize) {
            "Serialized user metadata is ${serialized.size()} bytes, but $reservedSize bytes are reserved for it"
        }

        file.seek(USER_METADATA_POSITION)
        file.write(serialized.toByteArray())
        userMetadata = metadata

        file.fd.sync()
    }


    /**
     * Returns the sequence of stored items.
     */
    fun items(): Sequence<T> = entries.asSequence().mapNotNull { it.item }


    /**
     * Adds a sequence of items to the collection.
     *
     * If adding an item fails, the items added before it remain in the collection, both in memory and on disk.
     */
    @Suppress("MemberVisibilityCanBePrivate")
    fun add(items: Sequence<T>) {
        val output = outputAt(file.length())
        try {
            for (item in items) {
                val entry = output.writeEntry(item, itemSerializer, buffer)
                entries.add(entry)
            }
        } finally {
            // Runs on failure too, so that the entries already appended in memory are also persisted
            output.flush()

            // Update the entry count on disk
            file.writeEntryCount(entries.size)
            file.fd.sync()
        }
    }


    /**
     * Adds the given items to the collection.
     */
    fun add(vararg items: T) {
        add(items.asSequence())
    }


    /**
     * Removes up to [maxRemoveCount] items matching the given predicate from the collection.
     * Returns the actual number of items removed.
     *
     * The predicate is not evaluated any further once [maxRemoveCount] items have been removed. A non-positive
     * [maxRemoveCount] removes nothing.
     */
    fun removeMatching(maxRemoveCount: Int = Int.MAX_VALUE, predicate: (T) -> Boolean): Int {
        // Find the matching indices; the sequence is lazy, so matching stops as soon as the limit is reached
        val matchingIndices = sequence {
            for ((index, entry) in entries.withIndex()) {
                val item = entry.item ?: continue
                if (predicate(item))
                    yield(index)
            }
        }.take(maxRemoveCount.coerceAtLeast(0))

        // Mark as deleted on disk and in entries
        var deletedCount = 0
        for (index in matchingIndices) {
            val entry = entries[index]

            // Mark as deleted on disk first, so that a failed write does not leave memory ahead of the file
            file.seek(dataPosition + entry.position)
            markEntryDeleted(file)

            // Mark as deleted in entries
            entries[index] = entry.deleted()

            ++deletedCount
        }

        // Trim the deleted entries at the end
        if (entries.isNotEmpty()) {
            // Note that this works correctly even if there is no non-deleted entry (indexOfLast returns -1)
            val firstTailIndex = entries.indexOfLast { !it.isDeleted } + 1
            if (firstTailIndex < entries.size) {
                val firstTailEntry = entries[firstTailIndex]
                file.setLength(dataPosition + firstTailEntry.position)  // Truncate file
                entries.removeTailFrom(firstTailIndex)  // Truncate entries
                file.writeEntryCount(entries.size)  // Update the entry count on disk
            }
        }

        file.fd.sync()

        return deletedCount
    }


    /**
     * Removes the given item from the collection. If the item appears multiple times in the collection, only one
     * instance will be removed.
     *
     * Returns whether the item was in the collection.
     */
    fun remove(item: T): Boolean {
        val count = removeMatching(1) {
            it == item
        }

        return count == 1
    }


    /**
     * Closes the underlying file. The collection will not be usable afterwards.
     */
    @Suppress("unused")
    override fun close() {
        file.close()
    }


    /**
     * Returns a [PositionDataOutputStream] at the given position.
     *
     * The position reported by the returned stream is relative to where the entries begin.
     */
    private fun outputAt(position: Long): PositionDataOutputStream {
        file.seek(position)
        return Channels.newOutputStream(file.channel).buffered().dataWithPosition(position - dataPosition)
    }


    /**
     * Returns a [PositionDataInputStream] at the given position.
     *
     * The position reported by the returned stream is relative to where the entries begin.
     */
    private fun inputAt(@Suppress("SameParameterValue") position: Long): PositionDataInputStream {
        file.seek(position)
        return Channels.newInputStream(file.channel).buffered().dataWithPosition(position - dataPosition)
    }


    /**
     * The interface for reading and writing items.
     */
    interface Serializer<T> {


        /**
         * The current item format version.
         * This must be constant throughout the [StoredCollection]'s lifetime.
         */
        val itemFormatVersion: Int


        /**
         * Reads an item, stored in the given format version, from the given [DataInput].
         */
        fun readItem(input: DataInput, formatVersion: Int): T


        /**
         * Writes the given item into the given [DataOutput].
         */
        fun writeItem(output: DataOutput, item: T)


    }


}


/**
 * An input stream that also remembers the number of bytes read from it.
 *
 * Only bytes that were actually consumed are counted: reaching the end of the stream does not move the position.
 */
private class PositionInputStream(


    /**
     * The input stream to read from.
     */
    inputStream: InputStream,


    /**
     * The initial position.
     */
    initialPosition: Long


) : FilterInputStream(inputStream) {


    /**
     * The position. For every byte read (or skipped) from this `InputStream`, the position is advanced by 1.
     */
    var position = initialPosition
        private set


    override fun read(): Int {
        val byte = super.read()
        if (byte >= 0)  // -1 signals the end of the stream; nothing was consumed
            ++position
        return byte
    }


    /**
     * Reads into the whole array.
     *
     * This delegates to the three-argument overload, which does the counting. Counting here as well would count the
     * bytes twice, because [FilterInputStream.read] for a whole array itself dispatches to that overload.
     */
    override fun read(b: ByteArray): Int {
        return read(b, 0, b.size)
    }


    override fun read(b: ByteArray, off: Int, len: Int): Int {
        val count = super.read(b, off, len)
        if (count > 0)  // -1 signals the end of the stream; nothing was consumed
            position += count
        return count
    }


    override fun skip(n: Long): Long {
        val skipped = super.skip(n)
        if (skipped > 0)
            position += skipped
        return skipped
    }


}


/**
 * A [DataInputStream] on top of a [PositionInputStream], exposing that stream's position.
 */
private class PositionDataInputStream(
    private val inputStream: PositionInputStream
) : DataInputStream(inputStream) {


    /**
     * The current position of the underlying [PositionInputStream].
     */
    val position: Long
        get() = inputStream.position


}


/**
 * Creates a [PositionDataInputStream] reading from this [InputStream], with its position starting at
 * [initialPosition].
 */
private fun InputStream.dataWithPosition(initialPosition: Long = 0): PositionDataInputStream {
    val countingStream = PositionInputStream(this, initialPosition)
    return PositionDataInputStream(countingStream)
}


/**
 * An output stream that keeps track of how many bytes have been written into it.
 *
 * @param outputStream The underlying [OutputStream] to write into.
 * @param initialPosition The position before anything has been written.
 */
private class PositionOutputStream(
    outputStream: OutputStream,
    initialPosition: Long
) : FilterOutputStream(outputStream) {


    /**
     * The position. It advances by 1 for every byte written into this `OutputStream`.
     */
    var position: Long = initialPosition
        private set


    override fun write(b: Int) {
        out.write(b)
        position += 1
    }


    override fun write(b: ByteArray) {
        out.write(b)
        position += b.size.toLong()
    }


    override fun write(b: ByteArray, off: Int, len: Int) {
        out.write(b, off, len)
        position += len.toLong()
    }


}



/**
 * A [DataOutputStream] on top of a [PositionOutputStream], exposing that stream's position.
 */
private class PositionDataOutputStream(
    private val outputStream: PositionOutputStream
) : DataOutputStream(outputStream) {


    /**
     * The current position of the underlying [PositionOutputStream].
     */
    val position: Long
        get() = outputStream.position


}


/**
 * Creates a [PositionDataOutputStream] writing into this [OutputStream], with its position starting at
 * [initialPosition].
 */
private fun OutputStream.dataWithPosition(initialPosition: Long = 0): PositionDataOutputStream {
    val countingStream = PositionOutputStream(this, initialPosition)
    return PositionDataOutputStream(countingStream)
}


/**
 * A [DataOutputStream] that writes into a [ByteArrayOutputStream] and provides some of its functionality.
 */
private class OutputBuffer(private val buffer: ByteArrayOutputStream) : DataOutputStream(buffer) {


    /**
     * Discards everything written so far, so that writing starts over at the beginning of the array.
     * The written-bytes counter of the [DataOutputStream] is zeroed as well.
     */
    fun reset() {
        written = 0
        buffer.reset()
    }


    /**
     * Copies the current contents of the buffer into the given [OutputStream].
     */
    fun writeTo(outputStream: OutputStream) = buffer.writeTo(outputStream)


}


/**
 * Creates an [OutputBuffer] backed by a byte array of the given initial size.
 */
private fun outputBuffer(initialSize: Int): OutputBuffer {
    val backingArray = ByteArrayOutputStream(initialSize)
    return OutputBuffer(backingArray)
}


/**
 * Thrown when the contents of the storage file are found to be corrupted.
 */
class CorruptedStoredCollection(message: String) : IOException(message)