Program Listing for File model.h

Return to documentation for file (include/parpecommon/model.h)

#ifndef PARPE_COMMON_MODEL_H
#define PARPE_COMMON_MODEL_H

#include <parpecommon/functions.h>

#include <vector>
#include <memory>

namespace parpe {

template <typename X>
class Model {
public:
    virtual ~Model() = default;

    virtual void evaluate(gsl::span<const double> parameters,
                          std::vector<X> const& features,
                          std::vector<double>& outputs) const;

    virtual void evaluate(gsl::span<const double> parameters,
                          std::vector<X> const& features,
                          std::vector<double>& outputs, // here only one output per model!
                          std::vector<std::vector<double>>& outputGradients) const = 0;

};


/**
 * @brief Model whose features are std::vector<double> entries.
 *
 * Judging by its name this is a linear model; the implementation of the
 * gradient-returning evaluate() lives outside this header.
 */
class LinearModel : public Model<std::vector<double>>
{
public:
    LinearModel() = default;

    // Declaring the override below would hide every base-class evaluate()
    // overload; re-expose them so the gradient-free variant stays callable.
    using Model<std::vector<double>>::evaluate;

    /**
     * @brief Evaluate the model and its output gradients (defined out of line).
     *
     * @param parameters      Model parameters.
     * @param features        Input feature vectors.
     * @param outputs         Receives the outputs (only one output per model).
     * @param outputGradients Receives the output gradients.
     */
    void evaluate(gsl::span<const double> parameters,
                  const std::vector<std::vector<double>>& features,
                  std::vector<double>& outputs,
                  std::vector<std::vector<double>>& outputGradients) const override;
};

/**
 * @brief Summed-gradient objective over a LinearModel, with datasets
 *        addressed by integer index.
 *
 * NOTE(review): "MSE" comes from the class name; the objective itself is
 * computed by the out-of-line std::vector<int> overload of evaluate() —
 * confirm the exact formula there.
 */
class LinearModelMSE : public SummedGradientFunction<int>
{
public:
    /**
     * @param numParameters Parameter count reported by numParameters() and
     *        used to size getParameterIds().
     */
    explicit LinearModelMSE(int numParameters)
        : numParameters_(numParameters) {}

    // SummedGradientFunction

    /**
     * @brief Evaluate the objective for a single dataset.
     *
     * Convenience overload: wraps @p dataset in a one-element index list and
     * delegates to the multi-dataset overload.
     */
    FunctionEvaluationStatus evaluate(
            gsl::span<const double> parameters,
            int dataset,
            double &fval,
            gsl::span<double> gradient,
            Logger *logger,
            double *cpuTime) const override {
        // Pass the index list as a temporary: it initializes the by-value
        // `dataIndices` parameter directly instead of being copied from a
        // named local (saving one heap allocation per call).
        return evaluate(parameters, std::vector<int>{dataset}, fval, gradient,
                        logger, cpuTime);
    }

    /**
     * @brief Evaluate the objective over the datasets listed in
     *        @p dataIndices (defined out of line).
     */
    FunctionEvaluationStatus evaluate(
            gsl::span<const double> parameters,
            std::vector<int> dataIndices,
            double &fval,
            gsl::span<double> gradient,
            Logger *logger,
            double *cpuTime) const override;

    /** @return The parameter count passed to the constructor. */
    int numParameters() const override { return numParameters_; }

    /**
     * @brief Generate one identifier per parameter: "p0", "p1", ...
     */
    std::vector<std::string> getParameterIds() const override {
        std::vector<std::string> ids(numParameters());
        for (int i = 0; i < static_cast<int>(ids.size()); ++i)
            ids[i] = "p" + std::to_string(i);
        return ids;
    }

    /// Parameter count returned by numParameters().
    int numParameters_ = 0;
    /// Feature vectors. NOTE(review): presumably indexed by dataset index — confirm.
    std::vector<std::vector<double>> datasets;
    /// NOTE(review): presumably target values matching `datasets` — confirm.
    std::vector<double> labels;
    /// Model instance; presumably used by the out-of-line evaluate().
    LinearModel lm;
};

}
#endif