Viam C++ SDK (current)
mlmodel.hpp
1// Copyright 2023 Viam Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#pragma once
16
17#include <iosfwd>
18
19#include <boost/mpl/joint_view.hpp>
20#include <boost/mpl/list.hpp>
21#include <boost/mpl/transform_view.hpp>
22#include <boost/variant/variant.hpp>
23
24#if defined(__has_include) && (__has_include(<xtensor/containers/xadapt.hpp>))
25#include <xtensor/containers/xadapt.hpp>
26#else
27#include <xtensor/xadapt.hpp>
28#endif
29
30#include <viam/sdk/common/utils.hpp>
31#include <viam/sdk/services/service.hpp>
32
33namespace viam {
34namespace sdk {
35
42class MLModelService : public Service {
43 private:
44 template <typename T>
45 struct make_tensor_view_ {
46 using shape_t = std::vector<std::size_t>;
47
48 using xt_no_ownership_t = decltype(xt::no_ownership());
49
50 using type = decltype(xt::adapt(std::declval<const T*>(),
51 std::declval<std::size_t>(),
52 std::declval<xt_no_ownership_t>(),
53 std::declval<shape_t>()));
54 };
55
56 public:
57 API api() const override;
58
59 template <typename T>
60 using tensor_view = typename make_tensor_view_<T>::type;
61
62 template <typename T>
63 static tensor_view<T> make_tensor_view(const T* data,
64 std::size_t size,
65 typename tensor_view<T>::shape_type shape) {
66 return xt::adapt(std::move(data), std::move(size), xt::no_ownership(), std::move(shape));
67 }
68
69 // Now that we have a factory for our tensor view types, use mpl
70 // to produce a variant over tensor views over the primitive types
71 // we care about, which are the signed and unsigned fixed width
72 // integral types and the two floating point types.
73 using signed_integral_base_types =
74 boost::mpl::list<std::int8_t, std::int16_t, std::int32_t, std::int64_t>;
75
76 using unsigned_integral_base_types =
77 boost::mpl::transform_view<signed_integral_base_types,
78 std::make_unsigned<boost::mpl::placeholders::_1>>;
79
80 using integral_base_types =
81 boost::mpl::joint_view<signed_integral_base_types, unsigned_integral_base_types>;
82
83 using fp_base_types = boost::mpl::list<float, double>;
84
85 using base_types = boost::mpl::joint_view<integral_base_types, fp_base_types>;
86
87 using tensor_view_types =
88 boost::mpl::transform_view<base_types, make_tensor_view_<boost::mpl::placeholders::_1>>;
89
90 // Union the tensor views for the various base types.
91 using tensor_views = boost::make_variant_over<tensor_view_types>::type;
92
93 // Our parameters to and from the model come as named tensor_views.
94 using named_tensor_views = std::unordered_map<std::string, tensor_views>;
95
102 inline std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs) {
103 return infer(inputs, {});
104 }
105
114 virtual std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs,
115 const ProtoStruct& extra) = 0;
116
117 struct tensor_info {
118 struct file {
119 std::string name;
120 std::string description;
121
122 enum : std::uint8_t {
123 k_label_type_tensor_value = 0,
124 k_label_type_tensor_axis = 1,
125 } label_type;
126 };
127
128 std::string name;
129 std::string description;
130
131 enum class data_types : std::uint8_t {
132 k_int8 = 0,
133 k_uint8 = 1,
134 k_int16 = 2,
135 k_uint16 = 3,
136 k_int32 = 4,
137 k_uint32 = 5,
138 k_int64 = 6,
139 k_uint64 = 7,
140 k_float32 = 8,
141 k_float64 = 9,
142 } data_type;
143
144 std::vector<int> shape;
145 std::vector<file> associated_files;
146
147 ProtoStruct extra;
148
149 static boost::optional<data_types> string_to_data_type(const std::string& str);
150 static const char* data_type_to_string(data_types data_type);
151
152 static data_types tensor_views_to_data_type(const tensor_views& view);
153 };
154
155 struct metadata {
156 std::string name;
157 std::string type;
158 std::string description;
159 std::vector<tensor_info> inputs;
160 std::vector<tensor_info> outputs;
161 };
162
164 inline struct metadata metadata() {
165 return metadata({});
166 }
167
171 virtual struct metadata metadata(const ProtoStruct& extra) = 0;
172
173 protected:
174 explicit MLModelService(std::string name);
175};
176
177template <>
178struct API::traits<MLModelService> {
179 static API api();
180};
181
182std::ostream& operator<<(std::ostream&, MLModelService::tensor_info::data_types);
183
184} // namespace sdk
185} // namespace viam
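
Usage sketch (not part of the header): the snippet below shows one way to wrap a caller-owned buffer in a non-owning tensor view with make_tensor_view, pass it to infer, and pull a typed view back out of the returned variant map. The service handle, the tensor names "image" and "probabilities", and the shape are illustrative assumptions; the tensors a concrete model actually expects and produces are reported by metadata().

#include <iostream>
#include <memory>
#include <vector>

#include <boost/variant/get.hpp>

#include <viam/sdk/services/mlmodel.hpp>

namespace vs = viam::sdk;

// Wrap a caller-owned float buffer as a non-owning tensor view, run
// inference, and read one value out of a named output. The tensor names
// and shape here are examples only; query metadata() for the real ones.
void run_inference(const std::shared_ptr<vs::MLModelService>& mlmodel,
                   const std::vector<float>& image_data) {
    vs::MLModelService::named_tensor_views inputs;
    inputs.emplace("image",
                   vs::MLModelService::make_tensor_view(
                       image_data.data(), image_data.size(), {1, 224, 224, 3}));

    // The convenience overload forwards to infer(inputs, {}).
    std::shared_ptr<vs::MLModelService::named_tensor_views> outputs = mlmodel->infer(inputs);

    // Each output is a boost::variant over tensor views; when the element
    // type is known up front, boost::get selects the concrete view.
    const auto& probs =
        boost::get<vs::MLModelService::tensor_view<float>>(outputs->at("probabilities"));
    std::cout << "first probability: " << probs(0, 0) << std::endl;
}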
Linked definitions and member briefs:

viam::sdk::API: resource_api.hpp:21
viam::sdk::API::traits: resource_api.hpp:46
viam::sdk::Service: service.hpp:10
viam::sdk::MLModelService: Represents a trained machine learning model instance. (mlmodel.hpp:42)
API api() const override: Returns the API associated with a particular resource.
std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs): Runs the model against the input tensors and returns inference results as tensors. (mlmodel.hpp:102)
virtual std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs, const ProtoStruct& extra) = 0: Runs the model against the input tensors and returns inference results as tensors.
struct metadata metadata(): Returns metadata describing the inputs and outputs of the model. (mlmodel.hpp:164)
virtual std::string name() const: Return the resource's name.
struct MLModelService::metadata: mlmodel.hpp:155
struct MLModelService::tensor_info: mlmodel.hpp:117
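
A second sketch, also illustrative rather than part of the SDK, uses metadata() and the tensor_info helpers declared above to list the tensors a loaded model expects and produces. The describe_model function name and the printed format are assumptions for the example.

#include <cstddef>
#include <iostream>
#include <memory>

#include <viam/sdk/services/mlmodel.hpp>

namespace vs = viam::sdk;

// Print the tensors a model advertises. data_type_to_string and the
// namespace-scope operator<< for data_types both come from mlmodel.hpp.
void describe_model(const std::shared_ptr<vs::MLModelService>& mlmodel) {
    const auto md = mlmodel->metadata();
    std::cout << "model '" << md.name << "' (" << md.type << ")\n";

    for (const vs::MLModelService::tensor_info& info : md.inputs) {
        std::cout << "  input  " << info.name << ": "
                  << vs::MLModelService::tensor_info::data_type_to_string(info.data_type)
                  << ", shape [";
        for (std::size_t i = 0; i != info.shape.size(); ++i) {
            std::cout << (i ? ", " : "") << info.shape[i];
        }
        std::cout << "]\n";
    }

    for (const vs::MLModelService::tensor_info& info : md.outputs) {
        // The streaming operator declared at the end of mlmodel.hpp can be
        // used to print the data type directly.
        std::cout << "  output " << info.name << ": " << info.data_type << "\n";
    }
}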