/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.norm;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.training.ParameterStore;
import ai.djl.translate.Batchifier;
import ai.djl.translate.StackBatchifier;
import ai.djl.util.PairList;

/**
 * A {@link BatchNorm} variant that normalizes a batch in smaller "virtual" (ghost) batches.
 *
 * <p>The input batch is split along axis 0 into {@code ceil(batchSize / virtualBatchSize)}
 * sub-batches. {@link BatchNorm} is applied to each sub-batch independently. The outputs are then
 * concatenated back along axis 0, so the result has the same leading dimension as the input.
 */
public class GhostBatchNorm
extends BatchNorm {
    private final int virtualBatchSize;
    private final Batchifier batchifier;

    /**
     * Creates a ghost batch normalization block.
     *
     * @param builder the builder carrying the virtual batch size and the {@link BatchNorm} settings
     */
    protected GhostBatchNorm(Builder builder) {
        super(builder);
        this.virtualBatchSize = builder.virtualBatchSize;
        this.batchifier = new StackBatchifier();
    }

    /**
     * Applies {@link BatchNorm} to each virtual batch of {@code inputs} separately and reassembles
     * the per-sub-batch outputs into a single batch.
     *
     * @param parameterStore the parameter store passed through to {@link BatchNorm}
     * @param inputs the input batch; its head array's axis 0 is the batch axis that gets split
     * @param training whether the block is in training mode, passed through unchanged
     * @param params extra parameters passed through unchanged
     * @return the normalized batch, with the same leading dimension as {@code inputs}
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDList[] subBatches = this.split(inputs);
        for (int i = 0; i < subBatches.length; ++i) {
            // Each sub-batch goes through BatchNorm on its own, so BatchNorm only ever sees
            // this slice of the batch.
            subBatches[i] = super.forwardInternal(parameterStore, subBatches[i], training, params);
        }
        return this.batchify(subBatches);
    }

    /**
     * Splits {@code list} along axis 0 into {@code ceil(batchSize / virtualBatchSize)} sub-batches.
     *
     * <p>Uneven splitting is allowed. When the batch size is not a multiple of the slice count, the
     * last sub-batch is smaller instead of the call failing.
     *
     * @param list the batch to split; the batch size is read from the head array's axis 0
     * @return the sub-batches, in their original order
     */
    protected NDList[] split(NDList list) {
        long batchSize = list.head().size(0);
        // Integer ceiling division; virtualBatchSize is guaranteed positive by the builder.
        int countBatches = (int) ((batchSize + this.virtualBatchSize - 1) / this.virtualBatchSize);
        // evenSplit = false: with `true`, any batch size not divisible by countBatches
        // (e.g. 129 with a virtual size of 128) would be rejected.
        return this.batchifier.split(list, countBatches, false);
    }

    /**
     * Joins the per-sub-batch outputs back into one batch by concatenating along axis 0.
     *
     * <p>Stacking plus {@code squeeze(0)} only restores the original layout when there is exactly
     * one sub-batch. With more, stacking leaves an extra leading axis of size k. Concatenation
     * restores the original batch axis for any number of sub-batches, including unequal sizes.
     *
     * @param subBatches the outputs for each sub-batch; all are expected to hold the same number
     *     of arrays
     * @return the reassembled batch, or an empty {@link NDList} if {@code subBatches} is empty
     */
    protected NDList batchify(NDList[] subBatches) {
        if (subBatches.length == 0) {
            return new NDList();
        }
        NDList first = subBatches[0];
        NDList batch = new NDList(first.size());
        for (int j = 0; j < first.size(); ++j) {
            NDArray merged = first.get(j);
            // Typically only a handful of sub-batches, so pairwise concatenation is acceptable.
            for (int i = 1; i < subBatches.length; ++i) {
                merged = merged.concat(subBatches[i].get(j), 0);
            }
            batch.add(merged);
        }
        return batch;
    }

    /**
     * Removes axis 0 from the single array held by {@code batch}, replacing it in place.
     *
     * <p>{@link #batchify(NDList[])} no longer uses this method. It is kept unchanged for
     * subclasses that call or override it.
     *
     * @param batch a list holding exactly one array (otherwise {@code singletonOrThrow} throws)
     * @return the same {@code batch} instance, with its element replaced by the squeezed array
     */
    protected NDList squeezeExtraDimensions(NDList batch) {
        NDArray array = batch.singletonOrThrow().squeeze(0);
        batch.set(0, array);
        return batch;
    }

    /**
     * Creates a builder for {@link GhostBatchNorm}.
     *
     * @return a new builder with a virtual batch size of 128
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link GhostBatchNorm}. */
    public static class Builder
    extends BatchNorm.BaseBuilder<Builder> {
        private int virtualBatchSize = 128;

        Builder() {
        }

        /**
         * Sets the number of samples per virtual batch. The default is 128.
         *
         * @param virtualBatchSize the virtual batch size; must be positive
         * @return this builder
         * @throws IllegalArgumentException if {@code virtualBatchSize} is zero or negative
         */
        public Builder optVirtualBatchSize(int virtualBatchSize) {
            // A non-positive value would make split() divide by zero or produce a negative
            // slice count.
            if (virtualBatchSize <= 0) {
                throw new IllegalArgumentException(
                        "virtualBatchSize must be positive, but was " + virtualBatchSize);
            }
            this.virtualBatchSize = virtualBatchSize;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link GhostBatchNorm} configured from this builder
         */
        @Override
        public GhostBatchNorm build() {
            return new GhostBatchNorm(this);
        }

        /** {@inheritDoc} */
        @Override
        public Builder self() {
            return this;
        }
    }
}

