/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.labs.mlrg.olcut.config.protobuf;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.TextFormat;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.ListProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.MapProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.ObjectProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.RootProvenanceProto;
import com.oracle.labs.mlrg.olcut.config.protobuf.protos.SimpleProvenanceProto;
import com.oracle.labs.mlrg.olcut.provenance.io.FlatMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Reads and writes lists of {@link ObjectMarshalledProvenance} as a {@link RootProvenanceProto}.
 * <p>
 * Depending on the constructor flag, data is stored either in the protobuf text format
 * (file extension {@code pbtxt}) or in the binary wire format (file extension {@code pb};
 * string serialization of the binary format is base64 encoded).
 * <p>
 * Every marshalled provenance node is written into one of the four repeated fields of the
 * root proto (object, simple, list, map) and is tagged with an integer index. List, map and
 * object nodes refer to their children by these indices.
 */
public final class ProtoProvenanceSerialization implements ProvenanceSerialization {
    private static final Base64.Encoder BASE64_ENCODER = Base64.getEncoder();
    private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder();

    /** If true the protobuf text format is used, otherwise the binary wire format. */
    private final boolean textFormat;

    /**
     * Constructs a protobuf provenance serializer.
     *
     * @param textFormat If true read/write the protobuf text format, otherwise the binary format.
     */
    public ProtoProvenanceSerialization(boolean textFormat) {
        this.textFormat = textFormat;
    }

    /**
     * Returns the file extension matching the configured format.
     *
     * @return {@code "pbtxt"} in text mode, {@code "pb"} in binary mode.
     */
    public String getFileExtension() {
        return textFormat ? "pbtxt" : "pb";
    }

    /**
     * Reads a list of object provenances from a file in the configured format.
     * <p>
     * The file stream is always closed before this method returns.
     *
     * @param path The file to read.
     * @return The object provenances stored in the file.
     * @throws IOException If the file could not be read.
     * @throws IllegalArgumentException If the contents are not a valid provenance protobuf.
     */
    public List<ObjectMarshalledProvenance> deserializeFromFile(Path path) throws IOException {
        try (InputStream is = Files.newInputStream(path)) {
            RootProvenanceProto proto;
            if (textFormat) {
                RootProvenanceProto.Builder protoBuilder = RootProvenanceProto.newBuilder();
                TextFormat.getParser().merge(new InputStreamReader(is, StandardCharsets.UTF_8), protoBuilder);
                proto = protoBuilder.build();
            } else {
                proto = RootProvenanceProto.parseFrom(is);
            }
            return deserializeFromProto(proto);
        } catch (InvalidProtocolBufferException | TextFormat.ParseException e) {
            throw new IllegalArgumentException("Failed to parse protobuf", e);
        }
    }

    /**
     * Reads a list of object provenances from a string in the configured format.
     * <p>
     * In binary mode the string is expected to be base64 encoded protobuf bytes.
     *
     * @param input The serialized provenance.
     * @return The object provenances stored in the string.
     * @throws IllegalArgumentException If the input is not valid base64 (binary mode) or not a
     *                                  valid provenance protobuf.
     */
    public List<ObjectMarshalledProvenance> deserializeFromString(String input) {
        try {
            RootProvenanceProto proto;
            if (textFormat) {
                RootProvenanceProto.Builder protoBuilder = RootProvenanceProto.newBuilder();
                TextFormat.getParser().merge(input, protoBuilder);
                proto = protoBuilder.build();
            } else {
                byte[] bytes = BASE64_DECODER.decode(input);
                proto = RootProvenanceProto.parseFrom(bytes);
            }
            return deserializeFromProto(proto);
        } catch (InvalidProtocolBufferException | TextFormat.ParseException e) {
            throw new IllegalArgumentException("Failed to parse protobuf", e);
        }
    }

    /**
     * Converts a root provenance proto back into marshalled provenance objects.
     * <p>
     * All nodes are first placed into an array by their index, then each object node's
     * values are resolved recursively through that array.
     *
     * @param proto The root proto.
     * @return One {@link ObjectMarshalledProvenance} per object node, in proto order.
     * @throws IllegalArgumentException If a node index is out of range or two nodes share an index.
     */
    public List<ObjectMarshalledProvenance> deserializeFromProto(RootProvenanceProto proto) {
        int totalProtos = proto.getLmpCount() + proto.getMmpCount() + proto.getOmpCount() + proto.getSmpCount();
        Message[] messages = new Message[totalProtos];
        for (ObjectProvenanceProto objectProvenanceProto : proto.getOmpList()) {
            registerMessage(messages, objectProvenanceProto.getIndex(), objectProvenanceProto);
        }
        for (SimpleProvenanceProto simpleProvenanceProto : proto.getSmpList()) {
            registerMessage(messages, simpleProvenanceProto.getIndex(), simpleProvenanceProto);
        }
        for (MapProvenanceProto mapProvenanceProto : proto.getMmpList()) {
            registerMessage(messages, mapProvenanceProto.getIndex(), mapProvenanceProto);
        }
        for (ListProvenanceProto listProvenanceProto : proto.getLmpList()) {
            registerMessage(messages, listProvenanceProto.getIndex(), listProvenanceProto);
        }
        // Every index is in [0, totalProtos) and unique, so every array slot is now filled.

        List<ObjectMarshalledProvenance> outputList = new ArrayList<>(proto.getOmpCount());
        for (ObjectProvenanceProto p : proto.getOmpList()) {
            Map<String, FlatMarshalledProvenance> provMap = new HashMap<>();
            for (Map.Entry<String, Integer> e : p.getValuesMap().entrySet()) {
                provMap.put(e.getKey(), dispatchMessage(messages, e.getValue()));
            }
            outputList.add(new ObjectMarshalledProvenance(p.getObjectName(), provMap,
                    p.getObjectClassName(), p.getProvenanceClassName()));
        }
        return outputList;
    }

    /**
     * Stores a message in the index table, rejecting out-of-range or duplicate indices.
     *
     * @param messages The index table.
     * @param index The message's declared index.
     * @param message The message to store.
     */
    private static void registerMessage(Message[] messages, int index, Message message) {
        checkIndex(messages, index);
        if (messages[index] != null) {
            throw new IllegalArgumentException("Invalid protobuf found, index " + index + " collided, found '"
                    + message.toString() + "' and '" + messages[index].toString() + "'");
        }
        messages[index] = message;
    }

    /**
     * Throws {@link IllegalArgumentException} if {@code index} is not a valid slot of {@code messages}.
     *
     * @param messages The index table.
     * @param index The index to check.
     */
    private static void checkIndex(Message[] messages, int index) {
        if (index < 0 || index >= messages.length) {
            throw new IllegalArgumentException("Invalid protobuf found, index " + index
                    + " is outside the valid range [0, " + messages.length + ")");
        }
    }

    /**
     * Decodes the non-object message stored at {@code index} into a flat marshalled provenance.
     *
     * @param messages The index table.
     * @param index The index of the message to decode.
     * @return The decoded provenance.
     * @throws IllegalArgumentException If the index is out of range.
     * @throws IllegalStateException If the index refers to an object provenance message.
     */
    private static FlatMarshalledProvenance dispatchMessage(Message[] messages, int index) {
        checkIndex(messages, index);
        Message curMessage = messages[index];
        if (curMessage instanceof SimpleProvenanceProto) {
            return decodeSMP((SimpleProvenanceProto) curMessage);
        }
        if (curMessage instanceof ListProvenanceProto) {
            return decodeLMP(messages, (ListProvenanceProto) curMessage);
        }
        if (curMessage instanceof MapProvenanceProto) {
            return decodeMMP(messages, (MapProvenanceProto) curMessage);
        }
        throw new IllegalStateException("Invalid protobuf, a message index points to an ObjectMarshalledProvenance");
    }

    /** Converts a simple provenance proto into a {@link SimpleMarshalledProvenance}. */
    private static SimpleMarshalledProvenance decodeSMP(SimpleProvenanceProto proto) {
        return new SimpleMarshalledProvenance(proto.getKey(), proto.getValue(),
                proto.getProvenanceClassName(), proto.getIsReference(), proto.getAdditional());
    }

    /** Converts a list provenance proto, resolving each element index through {@code messages}. */
    private static ListMarshalledProvenance decodeLMP(Message[] messages, ListProvenanceProto proto) {
        List<FlatMarshalledProvenance> list = new ArrayList<>(proto.getValuesCount());
        for (Integer i : proto.getValuesList()) {
            list.add(dispatchMessage(messages, i));
        }
        return new ListMarshalledProvenance(list);
    }

    /** Converts a map provenance proto, resolving each value index through {@code messages}. */
    private static MapMarshalledProvenance decodeMMP(Message[] messages, MapProvenanceProto proto) {
        Map<String, FlatMarshalledProvenance> map = new HashMap<>();
        for (Map.Entry<String, Integer> e : proto.getValuesMap().entrySet()) {
            map.put(e.getKey(), dispatchMessage(messages, e.getValue()));
        }
        return new MapMarshalledProvenance(map);
    }

    /**
     * Converts marshalled provenances into a single root proto.
     * <p>
     * Indices are assigned from a counter starting at zero; each node takes its index before
     * its children are encoded.
     *
     * @param marshalledProvenances The provenances to encode.
     * @return The root proto.
     */
    public RootProvenanceProto serializeToProto(List<ObjectMarshalledProvenance> marshalledProvenances) {
        RootProvenanceProto.Builder builder = RootProvenanceProto.newBuilder();
        MutableLong counter = new MutableLong(0L);
        for (ObjectMarshalledProvenance omp : marshalledProvenances) {
            convertProvenance(builder, counter, omp);
        }
        return builder.build();
    }

    /** Encodes one object provenance (and, recursively, its values) into {@code builder}. */
    private static void convertProvenance(RootProvenanceProto.Builder builder, MutableLong counter, ObjectMarshalledProvenance omp) {
        ObjectProvenanceProto.Builder ompBuilder = ObjectProvenanceProto.newBuilder();
        ompBuilder.setIndex(counter.intValue());
        counter.increment();
        ompBuilder.setObjectName(omp.getName());
        ompBuilder.setObjectClassName(omp.getObjectClassName());
        ompBuilder.setProvenanceClassName(omp.getProvenanceClassName());
        for (Map.Entry<String, FlatMarshalledProvenance> e : omp.getMap().entrySet()) {
            int childIndex = dispatchFMP(builder, counter, e.getValue());
            ompBuilder.putValues(e.getKey(), childIndex);
        }
        builder.addOmp(ompBuilder.build());
    }

    /**
     * Encodes a flat marshalled provenance into {@code builder}.
     *
     * @return The index assigned to the encoded node.
     * @throws IllegalStateException If {@code fmp} is of an unknown subclass.
     */
    private static int dispatchFMP(RootProvenanceProto.Builder builder, MutableLong counter, FlatMarshalledProvenance fmp) {
        if (fmp instanceof SimpleMarshalledProvenance) {
            return encodeSMP(builder, counter, (SimpleMarshalledProvenance) fmp);
        }
        if (fmp instanceof ListMarshalledProvenance) {
            return encodeLMP(builder, counter, (ListMarshalledProvenance) fmp);
        }
        if (fmp instanceof MapMarshalledProvenance) {
            return encodeMMP(builder, counter, (MapMarshalledProvenance) fmp);
        }
        throw new IllegalStateException("Should not reach here, unexpected FlatMarshalledProvenance subclass " + fmp.getClass());
    }

    /** Encodes a simple provenance and returns its index. */
    private static int encodeSMP(RootProvenanceProto.Builder builder, MutableLong counter, SimpleMarshalledProvenance smp) {
        SimpleProvenanceProto.Builder smpBuilder = SimpleProvenanceProto.newBuilder();
        int curIndex = counter.intValue();
        smpBuilder.setIndex(curIndex);
        counter.increment();
        smpBuilder.setKey(smp.getKey());
        smpBuilder.setValue(smp.getValue());
        smpBuilder.setAdditional(smp.getAdditional());
        smpBuilder.setProvenanceClassName(smp.getProvenanceClassName());
        smpBuilder.setIsReference(smp.isReference());
        builder.addSmp(smpBuilder.build());
        return curIndex;
    }

    /** Encodes a list provenance (recursively encoding its elements) and returns its index. */
    private static int encodeLMP(RootProvenanceProto.Builder builder, MutableLong counter, ListMarshalledProvenance lmp) {
        ListProvenanceProto.Builder lmpBuilder = ListProvenanceProto.newBuilder();
        int curIndex = counter.intValue();
        lmpBuilder.setIndex(curIndex);
        counter.increment();
        for (FlatMarshalledProvenance fmp : lmp) {
            lmpBuilder.addValues(dispatchFMP(builder, counter, fmp));
        }
        builder.addLmp(lmpBuilder.build());
        return curIndex;
    }

    /** Encodes a map provenance (recursively encoding its values) and returns its index. */
    private static int encodeMMP(RootProvenanceProto.Builder builder, MutableLong counter, MapMarshalledProvenance mmp) {
        MapProvenanceProto.Builder mmpBuilder = MapProvenanceProto.newBuilder();
        int curIndex = counter.intValue();
        mmpBuilder.setIndex(curIndex);
        counter.increment();
        for (Pair<String, FlatMarshalledProvenance> p : mmp) {
            mmpBuilder.putValues(p.getA(), dispatchFMP(builder, counter, p.getB()));
        }
        builder.addMmp(mmpBuilder.build());
        return curIndex;
    }

    /**
     * Serializes the provenances to a string.
     *
     * @param marshalledProvenances The provenances to serialize.
     * @return The proto's {@code toString()} text in text mode, otherwise the base64 encoded
     *         binary proto.
     */
    public String serializeToString(List<ObjectMarshalledProvenance> marshalledProvenances) {
        RootProvenanceProto proto = serializeToProto(marshalledProvenances);
        if (textFormat) {
            return proto.toString();
        }
        return BASE64_ENCODER.encodeToString(proto.toByteArray());
    }

    /**
     * Serializes the provenances to a file in the configured format.
     * <p>
     * Exactly one representation is written: text in text mode, binary protobuf otherwise.
     *
     * @param marshalledProvenances The provenances to serialize.
     * @param path The file to write.
     * @throws IOException If the file could not be written.
     */
    public void serializeToFile(List<ObjectMarshalledProvenance> marshalledProvenances, Path path) throws IOException {
        RootProvenanceProto proto = serializeToProto(marshalledProvenances);
        if (textFormat) {
            try (PrintWriter writer = new PrintWriter(Files.newBufferedWriter(path, StandardCharsets.UTF_8))) {
                writer.println(proto.toString());
                // PrintWriter swallows IOExceptions; surface them instead of silently truncating.
                if (writer.checkError()) {
                    throw new IOException("Failed to write protobuf text to " + path);
                }
            }
        } else {
            try (BufferedOutputStream bos = new BufferedOutputStream(Files.newOutputStream(path))) {
                proto.writeTo(bos);
            }
        }
    }
}

