/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.convolutional;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.convolutional.Convolution;
import ai.djl.util.Preconditions;
import java.util.Objects;

public class Conv2d
extends Convolution {
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.BATCH, LayoutType.CHANNEL, LayoutType.HEIGHT, LayoutType.WIDTH};
    private static final String STRING_LAYOUT = "NCHW";
    private static final int NUM_DIMENSIONS = 4;

    protected Conv2d(Builder builder) {
        super(builder);
    }

    @Override
    protected LayoutType[] getExpectedLayout() {
        return EXPECTED_LAYOUT;
    }

    @Override
    protected String getStringLayout() {
        return STRING_LAYOUT;
    }

    @Override
    protected int numDimensions() {
        return 4;
    }

    public static NDList conv2d(NDArray input, NDArray weight) {
        return Conv2d.conv2d(input, weight, null, new Shape(1L, 1L), new Shape(0L, 0L), new Shape(1L, 1L));
    }

    public static NDList conv2d(NDArray input, NDArray weight, NDArray bias) {
        return Conv2d.conv2d(input, weight, bias, new Shape(1L, 1L), new Shape(0L, 0L), new Shape(1L, 1L));
    }

    public static NDList conv2d(NDArray input, NDArray weight, NDArray bias, Shape stride) {
        return Conv2d.conv2d(input, weight, bias, stride, new Shape(0L, 0L), new Shape(1L, 1L));
    }

    public static NDList conv2d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding) {
        return Conv2d.conv2d(input, weight, bias, stride, padding, new Shape(1L, 1L));
    }

    public static NDList conv2d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape dilation) {
        return Conv2d.conv2d(input, weight, bias, stride, padding, dilation, 1);
    }

    public static NDList conv2d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape dilation, int groups) {
        Preconditions.checkArgument(input.getShape().dimension() == 4 && weight.getShape().dimension() == 4, "the shape of input or weight doesn't match the conv2d");
        Preconditions.checkArgument(stride.dimension() == 2 && padding.dimension() == 2 && dilation.dimension() == 2, "the shape of stride or padding or dilation doesn't match the conv2d");
        return Convolution.convolution(input, weight, bias, stride, padding, dilation, groups);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder
    extends Convolution.ConvolutionBuilder<Builder> {
        protected Builder() {
            this.stride = new Shape(1L, 1L);
            this.padding = new Shape(0L, 0L);
            this.dilation = new Shape(1L, 1L);
        }

        @Override
        protected Builder self() {
            return this;
        }

        public Conv2d build() {
            this.validate();
            return new Conv2d(this);
        }
    }
}

