#ifndef IDEEP_OPERATORS_LRN_HPP
#define IDEEP_OPERATORS_LRN_HPP

namespace ideep {

struct lrn_forward : public dnnl::lrn_forward {
  using super = dnnl::lrn_forward;

  static void compute(
      const tensor& src,
      tensor& dst,
      dim local_size,
      float alpha,
      float beta,
      float k = 1.0,
      algorithm aalgorithm = algorithm::lrn_across_channels,
      prop_kind aprop_kind = prop_kind::forward_training,
      const engine& aengine = engine::cpu_engine()) {
    // workaround: use src.get_desc() once issue intel/mkl-dnn#588 is resolved
    auto src_desc = src._get_unblocked_desc_if_4c_blocked();

    auto op_attr = dnnl::primitive_attr();
    op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

    // auto src_desc = src.get_desc();
    auto pd = primitive_desc(
        aengine, aprop_kind, aalgorithm, src_desc, src_desc,
        local_size, alpha, beta, k, op_attr);

    auto expected_src = src.reorder_if_differ_in(pd.src_desc());
    dst.reinit_if_possible(pd.dst_desc());
    tensor scratchpad(pd.scratchpad_desc());

    exec_args args{
        {DNNL_ARG_SRC, expected_src},
        {DNNL_ARG_DST, dst},
        {DNNL_ARG_SCRATCHPAD, scratchpad}};

    bool with_workspace = aprop_kind == prop_kind::forward_training;
    if (with_workspace) {
      dst.init_workspace(pd.workspace_desc());
      args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()});
    }

    super(pd).execute(stream::default_stream(), args);
  }
};

struct lrn_backward : public dnnl::lrn_backward {
  using super = dnnl::lrn_backward;

  static void compute(
      const tensor& src,
      const tensor& diff_dst,
      const tensor& dst,
      tensor& diff_src,
      dim local_size,
      float alpha,
      float beta,
      float k = 1.0,
      algorithm aalgorithm = algorithm::lrn_across_channels,
      const engine& aengine = engine::cpu_engine()) {
    IDEEP_CHECK(!(check_isa_is_avx2_vnni_2() &&
                  utils::one_of(diff_dst.get_data_type(),
                                data_type::bf16, data_type::f16)),
                  "DNNL does not support bf16/f16 backward on the platform with avx2_vnni_2");
    // workaround: use src.get_desc() once issue intel/mkl-dnn#588 is resolved
    auto src_desc = src._get_unblocked_desc_if_4c_blocked();
    // auto src_desc = src.get_desc();
    auto forward_hints = lrn_forward::primitive_desc(
        aengine, prop_kind::forward_training, aalgorithm, src_desc, src_desc,
        local_size, alpha, beta, k);

    auto op_attr = dnnl::primitive_attr();
    op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

    auto pd = primitive_desc(
        aengine, aalgorithm, diff_dst.get_desc(), diff_dst.get_desc(), src_desc,
        local_size, alpha, beta, k, forward_hints, op_attr);

    auto expected_diff_dst = diff_dst.reorder_if_differ_in(pd.diff_dst_desc());
    diff_src.reinit_if_possible(pd.diff_src_desc());
    tensor scratchpad(pd.scratchpad_desc());

    exec_args args{
        {DNNL_ARG_SRC, src},
        {DNNL_ARG_DIFF_DST, expected_diff_dst},
        {DNNL_ARG_DIFF_SRC, diff_src},
        {DNNL_ARG_SCRATCHPAD, scratchpad}};

    if (dst.has_workspace()) {
      args.insert({DNNL_ARG_WORKSPACE, dst.get_workspace()});
    }
    super(pd).execute(stream::default_stream(), args);
  }
};

} // namespace ideep

#endif