Program Listing for File model.h
↰ Return to documentation for file (include/parpecommon/model.h)
#ifndef PARPE_COMMON_MODEL_H
#define PARPE_COMMON_MODEL_H
#include <memory>
#include <string>
#include <vector>

#include <parpecommon/functions.h>
namespace parpe {
/**
 * @brief Abstract base for a model evaluated on a list of feature items
 * for a given parameter vector.
 *
 * @tparam X Type of one feature item in the `features` argument.
 */
template <typename X>
class Model {
public:
    virtual ~Model() = default;

    /**
     * @brief Evaluate the model without requesting output gradients.
     *
     * Declared (not pure) but not defined in this header.
     * NOTE(review): the definition presumably lives in a source file with an
     * explicit instantiation — confirm before adding one here.
     *
     * @param parameters Model parameter values.
     * @param features   Feature items to evaluate.
     * @param outputs    Receives the model outputs.
     */
    virtual void evaluate(gsl::span<const double> parameters,
                          const std::vector<X>& features,
                          std::vector<double>& outputs) const;

    /**
     * @brief Evaluate the model outputs together with their gradients.
     *
     * Pure virtual: every concrete model must implement this overload.
     *
     * @param parameters      Model parameter values.
     * @param features        Feature items to evaluate.
     * @param outputs         Receives the outputs; only one output per model.
     * @param outputGradients Receives the output gradients.
     *        NOTE(review): presumably gradients w.r.t. `parameters`, one
     *        vector per output — confirm against implementations.
     */
    virtual void evaluate(gsl::span<const double> parameters,
                          const std::vector<X>& features,
                          std::vector<double>& outputs,
                          std::vector<std::vector<double>>& outputGradients) const = 0;
};
/**
 * @brief Concrete Model whose feature items are `std::vector<double>`.
 *
 * NOTE(review): the name suggests outputs are linear in the parameters; the
 * gradient-returning evaluate() is defined out of line — confirm there.
 */
class LinearModel : public Model<std::vector<double>>
{
public:
    LinearModel() = default;

    // Re-expose every Model::evaluate overload. Without this, declaring the
    // override below would hide the base-class gradient-free overload.
    using Model::evaluate;

    /**
     * @brief Evaluate outputs and output gradients (overrides Model).
     *
     * Only one output per model. Defined outside this header.
     *
     * @param parameters      Model parameter values.
     * @param features        One `std::vector<double>` per feature item.
     * @param outputs         Receives the model outputs.
     * @param outputGradients Receives the output gradients.
     */
    void evaluate(gsl::span<const double> parameters,
                  const std::vector<std::vector<double>>& features,
                  std::vector<double>& outputs,
                  std::vector<std::vector<double>>& outputGradients) const override;
};
/**
 * @brief SummedGradientFunction over integer data indices built on a
 * LinearModel.
 *
 * NOTE(review): the name suggests a mean-squared-error objective; the
 * multi-index evaluate() is defined out of line — confirm there.
 */
class LinearModelMSE : public SummedGradientFunction<int>
{
public:
    /**
     * @param numParameters Value reported by numParameters(); also determines
     *        how many ids getParameterIds() generates.
     */
    explicit LinearModelMSE(int numParameters)
        : numParameters_(numParameters) {}

    // SummedGradientFunction

    /**
     * @brief Evaluate for a single data index.
     *
     * Forwards to the multi-index overload with a one-element index list.
     */
    FunctionEvaluationStatus evaluate(
            gsl::span<const double> parameters,
            int dataset,
            double &fval,
            gsl::span<double> gradient,
            Logger *logger,
            double *cpuTime) const override {
        // Braced temporary: a one-element list {dataset} (initializer_list
        // overload), moved into the by-value parameter rather than copied
        // from a named local.
        return evaluate(parameters, std::vector<int>{dataset},
                        fval, gradient, logger, cpuTime);
    }

    /**
     * @brief Evaluate over the given data indices. Defined out of line.
     */
    FunctionEvaluationStatus evaluate(
            gsl::span<const double> parameters,
            std::vector<int> dataIndices,
            double &fval,
            gsl::span<double> gradient,
            Logger *logger,
            double *cpuTime) const override;

    int numParameters() const override { return numParameters_; }

    /**
     * @brief Generated parameter ids: "p0", "p1", ..., one per parameter.
     */
    std::vector<std::string> getParameterIds() const override {
        std::vector<std::string> ids(numParameters());
        for (int i = 0; i < static_cast<int>(ids.size()); ++i)
            ids[i] = "p" + std::to_string(i);
        return ids;
    }

    // Number of parameters, set by the constructor.
    int numParameters_ = 0;
    // NOTE(review): presumably the feature vector for each data index — confirm.
    std::vector<std::vector<double>> datasets;
    // NOTE(review): presumably the target value for each data index — confirm.
    std::vector<double> labels;
    // Model instance, presumably used by the out-of-line evaluate().
    LinearModel lm;
};
}
#endif