/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BinaryFrameFrameCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryFrameMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryScalarScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryTensorTensorCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Abstract base for binary control-program (CP) instructions.
 *
 * <p>{@link #parseInstruction(String)} parses an instruction string and selects the concrete
 * subclass from the data types of the first two input operands. Subclasses may additionally use
 * the constructor and parse helper that accept a third input operand.
 */
public abstract class BinaryCPInstruction
extends ComputationCPInstruction {
    /**
     * Creates an instruction with two input operands.
     *
     * @param type   CP instruction type
     * @param op     operator the instruction applies
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   instruction string this instance was built from
     */
    protected BinaryCPInstruction(CPInstruction.CPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates an instruction with three input operands.
     *
     * @param type   CP instruction type
     * @param op     operator the instruction applies
     * @param in1    first input operand
     * @param in2    second input operand
     * @param in3    third input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   instruction string this instance was built from
     */
    protected BinaryCPInstruction(CPInstruction.CPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Parses a binary CP instruction string into the matching concrete instruction.
     *
     * <p>Scalar-scalar, matrix-matrix, tensor-tensor, frame-frame and frame-matrix input pairs
     * each map to a dedicated subclass. Every other pair falls through to
     * {@link BinaryMatrixScalarCPInstruction}.
     *
     * @param str instruction string to parse
     * @return the parsed instruction
     * @throws DMLRuntimeException if no input is a frame, at least one input is a matrix and the
     *         output is not a matrix (the parse helpers invoked here may raise other errors)
     */
    public static BinaryCPInstruction parseInstruction(String str) {
        CPOperand in1 = newUnknownOperand();
        CPOperand in2 = newUnknownOperand();
        CPOperand out = newUnknownOperand();
        String opcode = parseBinaryInstruction(str, in1, in2, out);

        // Instructions with a frame operand skip the matrix-output validation.
        boolean frameInvolved = in1.getDataType() == Types.DataType.FRAME
            || in2.getDataType() == Types.DataType.FRAME;
        if (!frameInvolved) {
            checkOutputDataType(in1, in2, out);
        }

        Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
        Types.DataType left = in1.getDataType();
        Types.DataType right = in2.getDataType();

        if (isPair(left, right, Types.DataType.SCALAR, Types.DataType.SCALAR)) {
            return new BinaryScalarScalarCPInstruction(operator, in1, in2, out, opcode, str);
        } else if (isPair(left, right, Types.DataType.MATRIX, Types.DataType.MATRIX)) {
            return new BinaryMatrixMatrixCPInstruction(operator, in1, in2, out, opcode, str);
        } else if (isPair(left, right, Types.DataType.TENSOR, Types.DataType.TENSOR)) {
            return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str);
        } else if (isPair(left, right, Types.DataType.FRAME, Types.DataType.FRAME)) {
            return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str);
        } else if (isPair(left, right, Types.DataType.FRAME, Types.DataType.MATRIX)) {
            return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str);
        } else {
            // NOTE(review): catch-all, presumably intended for matrix-scalar / scalar-matrix pairs.
            return new BinaryMatrixScalarCPInstruction(operator, in1, in2, out, opcode, str);
        }
    }

    /**
     * Splits a two-input instruction string and fills the given operands in place.
     *
     * <p>Three or four fields are accepted after the opcode. Only the first three fields are read
     * here, into {@code in1}, {@code in2} and {@code out}, in that order.
     *
     * @param instr instruction string
     * @param in1   receives operand field 1
     * @param in2   receives operand field 2
     * @param out   receives operand field 3
     * @return the opcode (field 0)
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
        String[] fields = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(fields, 3, 4);
        return splitOperands(fields, in1, in2, out);
    }

    /**
     * Splits a three-input instruction string and fills the given operands in place.
     *
     * <p>Exactly four fields are required after the opcode. They are read, in order, into
     * {@code in1}, {@code in2}, {@code in3} and {@code out}.
     *
     * @param instr instruction string
     * @param in1   receives operand field 1
     * @param in2   receives operand field 2
     * @param in3   receives operand field 3
     * @param out   receives operand field 4
     * @return the opcode (field 0)
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
        String[] fields = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(fields, 4);
        return splitOperands(fields, in1, in2, in3, out);
    }

    /**
     * Returns the operator this instruction was constructed with.
     *
     * @return the stored operator
     */
    @Override
    public Operator getOperator() {
        return _optr;
    }

    /**
     * Rejects element-wise operations that take a matrix input but do not declare a matrix output.
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param out output operand
     * @throws DMLRuntimeException if either input is a matrix and {@code out} is not a matrix
     */
    protected static void checkOutputDataType(CPOperand in1, CPOperand in2, CPOperand out) {
        boolean matrixInput = in1.getDataType() == Types.DataType.MATRIX
            || in2.getDataType() == Types.DataType.MATRIX;
        if (matrixInput && out.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() + " and " + in2.getName() + " must produce a matrix, which " + out.getName() + " is not");
        }
    }

    /** Creates a placeholder operand with an empty name and UNKNOWN value and data types. */
    private static CPOperand newUnknownOperand() {
        return new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
    }

    /** Returns true iff the observed data-type pair equals the expected pair. */
    private static boolean isPair(Types.DataType left, Types.DataType right, Types.DataType wantLeft, Types.DataType wantRight) {
        return left == wantLeft && right == wantRight;
    }

    /**
     * Reads the opcode from field 0, then splits fields 1..n into the targets in order.
     *
     * @param fields  instruction fields, already validated for count by the caller
     * @param targets operands to fill, one per field after the opcode
     * @return the opcode
     */
    private static String splitOperands(String[] fields, CPOperand... targets) {
        String opcode = fields[0];
        for (int idx = 0; idx < targets.length; idx++) {
            targets[idx].split(fields[idx + 1]);
        }
        return opcode;
    }
}

