// Copyright 2018-2019 H2O.AI, Inc. All Rights Reserved.
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include <Rcpp.h>

#include "c_api.h"

// Translates a MOJO column data type into the name of the R vector type used
// to represent it: "double" for 32/64-bit floats, "integer" for 32/64-bit
// integers, "character" for strings. Any other value, including MOJO_UNKNOWN,
// maps to "unknown".
static inline std::string mojo_data_type_to_r_type(MOJO_DataType type) {
  const char* r_type_name = "unknown";
  switch (type) {
    case MOJO_STRING:
      r_type_name = "character";
      break;
    case MOJO_INT32:
    case MOJO_INT64:
      r_type_name = "integer";
      break;
    case MOJO_FLOAT32:
    case MOJO_FLOAT64:
      r_type_name = "double";
      break;
    default:
      break;
  }
  return std::string(r_type_name);
}

class Mojo {
 public:
  Mojo(const std::string& filename, const std::string& tf_lib_prefix) {
    model_ = MOJO_NewModel(filename.c_str(), tf_lib_prefix.c_str());
    pipeline_ = MOJO_NewPipeline(model_, MOJO_Transform_Operations::PREDICT);
    frame_ = NULL;
    feature_names_.resize(model_->feature_count);
    feature_types_.resize(model_->feature_count);
    for (size_t k = 0; k < model_->feature_count; ++k) {
      feature_names_[k] = std::string(model_->feature_names[k]);
      feature_types_[k] = mojo_data_type_to_r_type(model_->feature_types[k]);
    }

    output_names_.resize(pipeline_->output_count);
    output_types_.resize(pipeline_->output_count);
    for (size_t k = 0; k < pipeline_->output_count; ++k) {
      output_names_[k] = std::string(pipeline_->output_names[k]);
      output_types_[k] = mojo_data_type_to_r_type(pipeline_->output_types[k]);
    }

    missing_values_.resize(model_->missing_values_count);
    for (size_t k = 0; k < model_->missing_values_count; ++k) {
      missing_values_[k] = std::string(model_->missing_values[k]);
    }

    uuid_ = std::string(model_->uuid);
  }

  Rcpp::DataFrame predict(const Rcpp::DataFrame& data) {
    if (!model_->is_valid) {
      return NULL;
    }

    size_t nrow = data.nrow();
    if (frame_ != NULL) {
      MOJO_DeleteFrame(frame_);
    }
    frame_ = MOJO_Pipeline_NewFrame(pipeline_, nrow);

    for (size_t k = 0; k < model_->feature_count; ++k) {
      SEXP input_k = data[feature_names_[k]];
      MOJO_DataType col_type = model_->feature_types[k];
      void* buffer = MOJO_Input_Data(pipeline_, frame_, k);

      switch (col_type) {
        case MOJO_FLOAT32: {
          std::vector<double> double_data =
              Rcpp::as<std::vector<double> >(input_k);
          std::vector<float> float_data(double_data.begin(), double_data.end());
          std::copy(float_data.begin(), float_data.end(), (float*)buffer);
          break;
        }
        case MOJO_FLOAT64: {
          std::vector<double> double_data =
              Rcpp::as<std::vector<double> >(input_k);
          std::copy(double_data.begin(), double_data.end(), (double*)buffer);
          break;
        }
        case MOJO_INT32: {
          std::vector<int> int_data = Rcpp::as<std::vector<int> >(input_k);
          std::copy(int_data.begin(), int_data.end(), (int*)buffer);
          break;
        }
        case MOJO_INT64: {
          std::vector<int> int_data = Rcpp::as<std::vector<int> >(input_k);
          std::vector<int64_t> int64_data(int_data.begin(), int_data.end());
          std::copy(int64_data.begin(), int64_data.end(), (int64_t*)buffer);
          break;
        }
        case MOJO_STRING: {
          std::vector<std::string> str_data =
              Rcpp::as<std::vector<std::string> >(input_k);
          for (size_t i = 0; i < nrow; ++i) {
            MOJO_Column_Write_Str(buffer, i, str_data[i].c_str());
          }
          break;
        }
        case MOJO_UNKNOWN:
        default: {
          Rcpp::Rcerr << "unknown data type found for feature "
                      << feature_names_[k] << std::endl;
          break;
        }
      }
    }

    MOJO_Transform(pipeline_, frame_, nrow, true);

    Rcpp::List res(output_names_.size());
    for (size_t k = 0; k < output_names_.size(); ++k) {
      void* buffer = MOJO_Output_Data(pipeline_, frame_, k);
      MOJO_DataType col_type = pipeline_->output_types[k];
      switch (col_type) {
        case MOJO_FLOAT32: {
          std::vector<double> d(nrow);
          float* data_k = (float*)buffer;
          for (size_t i = 0; i < d.size(); ++i) {
            d[i] = data_k[i];
          }
          res[k] = Rcpp::wrap(d);
          break;
        }
        case MOJO_FLOAT64: {
          std::vector<double> d(nrow);
          double* data_k = (double*)buffer;
          for (size_t i = 0; i < d.size(); ++i) {
            d[i] = data_k[i];
          }
          res[k] = Rcpp::wrap(d);
          break;
        }
        case MOJO_INT32: {
          std::vector<int32_t> d(nrow);
          int32_t* data_k = (int32_t*)buffer;
          for (size_t i = 0; i < d.size(); ++i) {
            d[i] = data_k[i];
          }
          res[k] = Rcpp::wrap(d);
          break;
        }
        case MOJO_INT64: {
          std::vector<int64_t> d(nrow);
          int64_t* data_k = (int64_t*)buffer;
          for (size_t i = 0; i < d.size(); ++i) {
            d[i] = data_k[i];
          }
          res[k] = Rcpp::wrap(d);
          break;
        }
        case MOJO_STRING: {
          SEXP out = PROTECT(Rf_allocVector(STRSXP, nrow));
          for (size_t i = 0; i < nrow; ++i) {
            const char* s = MOJO_Column_Read_Str(buffer, i);
            SET_STRING_ELT(out, i, Rf_mkChar(s));
          }
          UNPROTECT(1);
          res[k] = out;
          break;
        }
        case MOJO_UNKNOWN:
        default: {
          Rcpp::Rcerr << "unknown data type found for output column "
                      << output_names_[k] << std::endl;
          break;
        }
      }
    }

    res.attr("names") = output_names_;
    return res;
  }

  std::vector<std::string> missing_values() { return missing_values_; }

  std::vector<std::string> feature_names() { return feature_names_; }

  std::vector<std::string> feature_types() { return feature_types_; }

  std::vector<std::string> output_names() { return output_names_; }

  std::vector<std::string> output_types() { return output_types_; }

  bool is_valid() { return model_->is_valid; }

  int64_t created_time() { return model_->time_created; }

  std::string uuid() { return uuid_; }

  const double binomial_problem_default_threshold() {
    if (model_->problem == nullptr)
      Rcpp::stop("The mojo file doesn't contain problem's data.");
    if (!model_->problem->supervised_detail().has_binomial_problem_detail())
      Rcpp::stop("The mojo file doesn't contain binomial_problem.");
    return model_->problem->supervised_detail()
        .binomial_problem_detail()
        .default_threshold();
  }

  std::vector<std::string> binomial_problem_labels() {
    if (model_->problem == nullptr)
      Rcpp::stop("The mojo file doesn't contain problem's data.");
    if (!model_->problem->supervised_detail().has_binomial_problem_detail())
      Rcpp::stop("The mojo file doesn't contain binomial_problem.");
    // It's not possible to use strings from `problem` directly due to CXX11_ABI issue.
    // daimojo is compiled with -D_GLIBCXX_USE_CXX11_ABI=0 but client may use gcc>=5.
    // There're multiple options with their pros and cons, but adding redundant fields
    // look better then the others(multiple libdaimojo.so, switching to CXX11_ABI=1).
    auto val = std::vector<std::string>(model_->binomial_problem_labels_count);
    for (size_t k = 0; k < model_->binomial_problem_labels_count; ++k)
      val[k] = std::string(model_->binomial_problem_labels[k]);

    return val;
  }

  ~Mojo() {
    if (frame_ != NULL) {
      MOJO_DeleteFrame(frame_);
    }
    MOJO_DeletePipeline(pipeline_);
    MOJO_DeleteModel(model_);
  }

 private:
  MOJO_Frame* frame_;
  MOJO_Pipeline* pipeline_;
  MOJO_Model* model_;

  std::vector<std::string> feature_names_, output_names_;
  std::vector<std::string> feature_types_, output_types_;
  std::vector<std::string> missing_values_;
  std::string uuid_;
};

// Registers the Mojo wrapper with R as the class "rcppmojo". It exposes:
//   - the two-string constructor (filename, tf_lib_prefix);
//   - the metadata getters;
//   - predict();
//   - the binomial-problem accessors.
RCPP_MODULE(rcppmojo) {
  Rcpp::class_<Mojo> exposed("rcppmojo");
  exposed
      // Construction: Mojo(filename, tf_lib_prefix).
      .constructor<std::string, std::string>()
      // Model-level metadata.
      .method("is_valid", &Mojo::is_valid)
      .method("created_time", &Mojo::created_time)
      .method("missing_values", &Mojo::missing_values)
      .method("uuid", &Mojo::uuid)
      // Input schema.
      .method("feature_names", &Mojo::feature_names)
      .method("feature_types", &Mojo::feature_types)
      // Output schema.
      .method("output_names", &Mojo::output_names)
      .method("output_types", &Mojo::output_types)
      // Scoring.
      .method("predict", &Mojo::predict)
      // Binomial problem details.
      .method("binomial_problem_default_threshold",
              &Mojo::binomial_problem_default_threshold)
      .method("binomial_problem_labels", &Mojo::binomial_problem_labels);
}
