/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ProcessNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.AggregationDecorrelation;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AssignUniqueId;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanNodeDecorrelator;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.Type;

/**
 * Decorrelates a {@link CorrelatedJoinNode} whose subquery is a global aggregation (an {@link
 * AggregationNode} with no grouping columns) sitting directly on top of its source, with no
 * projection in between.
 *
 * <p>The rule only matches correlated joins that have a non-empty correlation list and a join
 * filter equal to {@code TRUE}. When decorrelation succeeds, the plan
 *
 * <pre>
 * CorrelatedJoin[correlation, filter = TRUE]
 *   - input
 *   - Aggregation[global]
 *       - source (correlated)
 * </pre>
 *
 * is rewritten into roughly
 *
 * <pre>
 * [Project (restrict to the correlated join's outputs)]
 *   - Aggregation[group by: join left outputs + original grouping keys, masked by non_null]
 *       - [Project (combined mask symbols)]        only if some aggregation already had a mask
 *           - [Aggregation[distinct] (restored)]   only if a distinct had to be split off
 *               - LeftJoin[filter: correlated predicates]
 *                   - AssignUniqueId(input)
 *                   - Project[non_null := TRUE]
 *                       - decorrelated source
 * </pre>
 */
public class TransformCorrelatedGlobalAggregationWithoutProjection
        implements Rule<CorrelatedJoinNode> {
    private static final Capture<AggregationNode> AGGREGATION = Capture.newCapture();
    private static final Capture<PlanNode> SOURCE = Capture.newCapture();

    private static final Pattern<CorrelatedJoinNode> PATTERN =
            Patterns.correlatedJoin()
                    .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
                    .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
                    .with(
                            Patterns.CorrelatedJoin.subquery()
                                    .matching(
                                            Patterns.aggregation()
                                                    .with(Pattern.empty(Patterns.Aggregation.groupingColumns()))
                                                    .with(Patterns.source().capturedAs(SOURCE))
                                                    .capturedAs(AGGREGATION)));

    private final PlannerContext plannerContext;

    public TransformCorrelatedGlobalAggregationWithoutProjection(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched correlated join into a LEFT join followed by a grouped aggregation.
     *
     * @param correlatedJoinNode the matched correlated join; must be INNER or LEFT
     * @param captures captures holding the global aggregation and its source
     * @param context rule context providing the symbol allocator, id allocator and lookup
     * @return the rewritten plan, or {@link Rule.Result#empty()} when the subquery source cannot
     *     be decorrelated (neither directly nor after splitting off a distinct operator)
     * @throws IllegalArgumentException if the correlated join type is neither INNER nor LEFT
     */
    @Override
    public Rule.Result apply(
            CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Preconditions.checkArgument(
                correlatedJoinNode.getJoinType() == JoinNode.JoinType.INNER
                        || correlatedJoinNode.getJoinType() == JoinNode.JoinType.LEFT,
                "unexpected correlated join type: %s",
                correlatedJoinNode.getJoinType());

        PlanNode source = captures.get(SOURCE);
        // Non-null only when the source could not be decorrelated as a whole but could be once a
        // distinct operator on top of it was split off; that distinct is re-applied above the join.
        AggregationNode distinct = null;

        PlanNodeDecorrelator decorrelator =
                new PlanNodeDecorrelator(
                        this.plannerContext, context.getSymbolAllocator(), context.getLookup());
        Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource =
                decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
        if (!decorrelatedSource.isPresent()) {
            // Fallback: peel off a distinct operator and retry on its child.
            if (AggregationDecorrelation.isDistinctOperator(source)) {
                distinct = (AggregationNode) source;
                source = distinct.getChild();
                decorrelatedSource =
                        decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
            }
            if (!decorrelatedSource.isPresent()) {
                return Rule.Result.empty();
            }
        }
        source = decorrelatedSource.get().getNode();

        // Tag every row coming from the subquery side with non_null = TRUE. Rows that the LEFT join
        // pads for unmatched input rows will carry NULL here instead, which lets the aggregation
        // below ignore them via masks.
        Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BooleanType.getInstance());
        source =
                new ProjectNode(
                        context.getIdAllocator().genPlanNodeId(),
                        source,
                        Assignments.builder()
                                .putIdentities(source.getOutputSymbols())
                                .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                                .build());

        // Attach a generated "unique" symbol to the input. It becomes part of the join's left
        // outputs, which are later used as grouping keys, so presumably each original input row
        // stays in its own group.
        PlanNode inputWithUniqueId =
                new AssignUniqueId(
                        context.getIdAllocator().genPlanNodeId(),
                        correlatedJoinNode.getInput(),
                        context.getSymbolAllocator().newSymbol("unique", LongType.getInstance()));

        // The correlated predicates extracted by the decorrelator become the join filter.
        JoinNode join =
                new JoinNode(
                        context.getIdAllocator().genPlanNodeId(),
                        JoinNode.JoinType.LEFT,
                        inputWithUniqueId,
                        source,
                        ImmutableList.<JoinNode.EquiJoinClause>of(),
                        inputWithUniqueId.getOutputSymbols(),
                        source.getOutputSymbols(),
                        decorrelatedSource.get().getCorrelatedPredicates(),
                        Optional.empty());

        PlanNode root = join;
        if (distinct != null) {
            // Re-apply the split-off distinct, now also keyed by the left outputs and non_null.
            root =
                    AggregationDecorrelation.restoreDistinctAggregation(
                            distinct,
                            join,
                            ImmutableList.<Symbol>builder()
                                    .addAll(join.getLeftOutputSymbols())
                                    .add(nonNull)
                                    .addAll(distinct.getGroupingKeys())
                                    .build());
        }

        // Mask every aggregation with non_null. An aggregation that already has a mask gets a new
        // mask symbol bound to (existing_mask AND non_null).
        AggregationNode globalAggregation = captures.get(AGGREGATION);
        ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
        Assignments.Builder maskAssignmentsBuilder = Assignments.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
                globalAggregation.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getMask().isPresent()) {
                Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BooleanType.getInstance());
                Expression combinedMask =
                        IrUtils.and(
                                aggregation.getMask().get().toSymbolReference(),
                                nonNull.toSymbolReference());
                maskAssignmentsBuilder.put(newMask, combinedMask);
                masks.put(entry.getKey(), newMask);
            } else {
                masks.put(entry.getKey(), nonNull);
            }
        }
        Assignments maskAssignments = maskAssignmentsBuilder.build();
        if (!maskAssignments.isEmpty()) {
            root =
                    new ProjectNode(
                            context.getIdAllocator().genPlanNodeId(),
                            root,
                            Assignments.builder()
                                    .putIdentities(root.getOutputSymbols())
                                    .putAll(maskAssignments)
                                    .build());
        }

        // Rebuild the aggregation (same plan node id), now grouped per left-side row.
        AggregationNode rewrittenAggregation =
                new AggregationNode(
                        globalAggregation.getPlanNodeId(),
                        root,
                        AggregationDecorrelation.rewriteWithMasks(
                                globalAggregation.getAggregations(), masks.buildOrThrow()),
                        AggregationNode.singleGroupingSet(
                                ImmutableList.<Symbol>builder()
                                        .addAll(join.getLeftOutputSymbols())
                                        .addAll(globalAggregation.getGroupingKeys())
                                        .build()),
                        ImmutableList.<Symbol>of(),
                        globalAggregation.getStep(),
                        Optional.empty(),
                        Optional.empty());

        // Drop helper symbols (unique id, non_null, masks) that the correlated join did not expose.
        Optional<PlanNode> project =
                Util.restrictOutputs(
                        context.getIdAllocator(),
                        rewrittenAggregation,
                        ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
        return Rule.Result.ofPlanNode(project.orElse(rewrittenAggregation));
    }
}

