async_mqtt 4.1.0
Loading...
Searching...
No Matches
stream.hpp
1// Copyright Takatoshi Kondo 2022
2//
3// Distributed under the Boost Software License, Version 1.0.
4// (See accompanying file LICENSE_1_0.txt or copy at
5// http://www.boost.org/LICENSE_1_0.txt)
6
7#if !defined(ASYNC_MQTT_STREAM_HPP)
8#define ASYNC_MQTT_STREAM_HPP
9
10#include <iostream>
11
12#include <utility>
13#include <type_traits>
14
15#include <boost/system/error_code.hpp>
16#include <boost/asio/strand.hpp>
17#include <boost/asio/io_context.hpp>
18#include <boost/asio/compose.hpp>
19#include <boost/asio/bind_executor.hpp>
20#include <boost/asio/append.hpp>
21#include <boost/asio/consign.hpp>
22
23#if defined(ASYNC_MQTT_USE_WS)
24#include <boost/beast/websocket/stream.hpp>
25#endif // defined(ASYNC_MQTT_USE_WS)
26
27#include <async_mqtt/core/stream_traits.hpp>
28#include <async_mqtt/util/make_shared_helper.hpp>
29#include <async_mqtt/util/optional.hpp>
30#include <async_mqtt/util/static_vector.hpp>
31#include <async_mqtt/util/ioc_queue.hpp>
32#include <async_mqtt/buffer.hpp>
33#include <async_mqtt/constant.hpp>
34#include <async_mqtt/is_strand.hpp>
35#include <async_mqtt/ws_fixed_size_async_read.hpp>
36#include <async_mqtt/exception.hpp>
37#include <async_mqtt/tls.hpp>
38#include <async_mqtt/log.hpp>
39
40namespace async_mqtt {
41
42namespace as = boost::asio;
43namespace sys = boost::system;
44
/// Trait: std::true_type when Stream is a `bs::websocket::stream<NextLayer>`,
/// std::false_type otherwise. The websocket specialization only exists when
/// ASYNC_MQTT_USE_WS is defined, so without it every type yields false.
template <typename Stream>
struct is_ws : public std::false_type {};

#if defined(ASYNC_MQTT_USE_WS)
template <typename NextLayer>
struct is_ws<bs::websocket::stream<NextLayer>> : public std::true_type {};
#endif // defined(ASYNC_MQTT_USE_WS)
52
/// Trait: std::true_type when Stream is a `tls::stream<NextLayer>`,
/// std::false_type otherwise. The TLS specialization only exists when
/// ASYNC_MQTT_USE_TLS is defined, so without it every type yields false.
template <typename Stream>
struct is_tls : public std::false_type {};

#if defined(ASYNC_MQTT_USE_TLS)
template <typename NextLayer>
struct is_tls<tls::stream<NextLayer>> : public std::true_type {};
#endif // defined(ASYNC_MQTT_USE_TLS)
60
61template <typename NextLayer>
62class stream : public std::enable_shared_from_this<stream<NextLayer>> {
63public:
64 using this_type = stream<NextLayer>;
65 using this_type_sp = std::shared_ptr<this_type>;
66 using next_layer_type = typename std::remove_reference<NextLayer>::type;
68 using raw_strand_type = as::strand<executor_type>;
69 using strand_type = as::strand<as::any_io_executor>;
70
71 template <typename T>
72 friend class make_shared_helper;
73
74 template <
75 typename T,
76 typename... Args,
77 std::enable_if_t<!std::is_same_v<std::decay_t<T>, this_type>>* = nullptr
78 >
79 static std::shared_ptr<this_type> create(T&& t, Args&&... args) {
80 return make_shared_helper<this_type>::make_shared(std::forward<T>(t), std::forward<Args>(args)...);
81 }
82
83 ~stream() {
84 ASYNC_MQTT_LOG("mqtt_impl", trace)
85 << ASYNC_MQTT_ADD_VALUE(address, this)
86 << "destroy";
87 }
88
89 stream(this_type&&) = delete;
90 stream(this_type const&) = delete;
91 this_type& operator=(this_type&&) = delete;
92 this_type& operator=(this_type const&) = delete;
93
94 auto const& next_layer() const {
95 return nl_;
96 }
97 auto& next_layer() {
98 return nl_;
99 }
100
101 auto const& lowest_layer() const {
102 return get_lowest_layer(nl_);
103 }
104 auto& lowest_layer() {
105 return get_lowest_layer(nl_);
106 }
107
108 auto get_executor() const {
109 return nl_.get_executor();
110 }
111 auto get_executor() {
112 return nl_.get_executor();
113 }
114
115 template <typename CompletionToken>
116 auto
117 read_packet(
118 CompletionToken&& token
119 ) {
120 return
121 as::async_compose<
122 CompletionToken,
123 void(error_code const&, buffer)
124 >(
125 read_packet_impl{
126 *this
127 },
128 token
129 );
130 }
131
132 template <typename Packet, typename CompletionToken>
133 auto
134 write_packet(
135 Packet packet,
136 CompletionToken&& token
137 ) {
138 return
139 as::async_compose<
140 CompletionToken,
141 void(error_code const&, std::size_t)
142 >(
143 write_packet_impl<Packet>{
144 *this,
145 std::make_shared<Packet>(force_move(packet))
146 },
147 token
148 );
149 }
150
151 strand_type const& strand() const {
152 return strand_;
153 }
154
155 strand_type& strand() {
156 return strand_;
157 }
158
159 raw_strand_type const& raw_strand() const {
160 return raw_strand_;
161 };
162
163 raw_strand_type& raw_strand() {
164 return raw_strand_;
165 };
166
167 bool in_strand() const {
168 return raw_strand().running_in_this_thread();
169 }
170
171 template<typename CompletionToken>
172 auto
173 close(CompletionToken&& token) {
174 return
175 as::async_compose<
176 CompletionToken,
177 void(error_code const&)
178 >(
179 close_impl{
180 *this
181 },
182 token
183 );
184 }
185
186private:
187
188 // constructor
189 template <
190 typename T,
191 typename... Args,
192 std::enable_if_t<!std::is_same_v<std::decay_t<T>, this_type>>* = nullptr
193 >
194 explicit
195 stream(T&& t, Args&&... args)
196 :nl_{std::forward<T>(t), std::forward<Args>(args)...}
197 {
198 if constexpr(is_ws<next_layer_type>::value) {
199 nl_.binary(true);
200 nl_.set_option(
201 bs::websocket::stream_base::decorator(
202 [](bs::websocket::request_type& req) {
203 req.set("Sec-WebSocket-Protocol", "mqtt");
204 }
205 )
206 );
207 }
208 }
209
210 struct read_packet_impl {
211 this_type& strm;
212 std::size_t received = 0;
213 std::uint32_t mul = 1;
214 std::uint32_t rl = 0;
215 shared_ptr_array spa = nullptr;
216 this_type_sp life_keeper = strm.shared_from_this();
217 enum { dispatch, header, remaining_length, complete } state = dispatch;
218
219 template <typename Self>
220 void operator()(
221 Self& self
222 ) {
223 switch (state) {
224 case dispatch: {
225 state = header;
226 auto& a_strm{strm};
227 as::dispatch(
228 as::bind_executor(
229 a_strm.raw_strand_,
230 force_move(self)
231 )
232 );
233 } break;
234 case header: {
235 BOOST_ASSERT(strm.in_strand());
236 // read fixed_header
237 auto address = &strm.header_remaining_length_buf_[received];
238 auto& a_strm{strm};
239 async_read(
240 a_strm.nl_,
241 as::buffer(address, 1),
242 as::bind_executor(
243 a_strm.raw_strand_,
244 force_move(self)
245 )
246 );
247 } break;
248 default:
249 BOOST_ASSERT(false);
250 break;
251 }
252 }
253
254 template <typename Self>
255 void operator()(
256 Self& self,
257 error_code const& ec,
258 std::size_t bytes_transferred
259 ) {
260 BOOST_ASSERT(strm.in_strand());
261 if (ec) {
262 self.complete(ec, buffer{});
263 return;
264 }
265
266 switch (state) {
267 case header:
268 BOOST_ASSERT(bytes_transferred == 1);
269 state = remaining_length;
270 ++received;
271 // read the first remaining_length
272 {
273 auto address = &strm.header_remaining_length_buf_[received];
274 auto& a_strm{strm};
275 async_read(
276 a_strm.nl_,
277 as::buffer(address, 1),
278 as::bind_executor(
279 a_strm.raw_strand_,
280 force_move(self)
281 )
282 );
283 }
284 break;
285 case remaining_length:
286 BOOST_ASSERT(bytes_transferred == 1);
287 ++received;
288 if (strm.header_remaining_length_buf_[received - 1] & 0b10000000) {
289 // remaining_length continues
290 if (received == 5) {
291 ASYNC_MQTT_LOG("mqtt_impl", warning)
292 << ASYNC_MQTT_ADD_VALUE(address, this)
293 << "out of size remaining length";
294 self.complete(
295 sys::errc::make_error_code(sys::errc::protocol_error),
296 buffer{}
297 );
298 return;
299 }
300 rl += (strm.header_remaining_length_buf_[received - 1] & 0b01111111) * mul;
301 mul *= 128;
302 auto address = &strm.header_remaining_length_buf_[received];
303 auto& a_strm{strm};
304 async_read(
305 a_strm.nl_,
306 as::buffer(address, 1),
307 as::bind_executor(
308 a_strm.raw_strand_,
309 force_move(self)
310 )
311 );
312 }
313 else {
314 // remaining_length end
315 rl += (strm.header_remaining_length_buf_[received - 1] & 0b01111111) * mul;
316
317 spa = make_shared_ptr_array(received + rl);
318 std::copy(
319 strm.header_remaining_length_buf_.data(),
320 strm.header_remaining_length_buf_.data() + received, spa.get()
321 );
322
323 if (rl == 0) {
324 auto ptr = spa.get();
325 self.complete(ec, buffer{ptr, ptr + received + rl, force_move(spa)});
326 return;
327 }
328 else {
329 state = complete;
330 auto address = &spa[std::ptrdiff_t(received)];
331 auto& a_strm{strm};
332 async_read(
333 a_strm.nl_,
334 as::buffer(address, rl),
335 as::bind_executor(
336 a_strm.raw_strand_,
337 force_move(self)
338 )
339 );
340 }
341 }
342 break;
343 case complete: {
344 auto ptr = spa.get();
345 self.complete(ec, buffer{ptr, ptr + received + rl, force_move(spa)});
346 } break;
347 default:
348 BOOST_ASSERT(false);
349 break;
350 }
351 }
352 };
353
354 template <typename Packet>
355 struct write_packet_impl {
356 this_type& strm;
357 std::shared_ptr<Packet> packet;
358 this_type_sp life_keeper = strm.shared_from_this();
359 enum { dispatch, post, write, complete } state = dispatch;
360
361 template <typename Self>
362 void operator()(
363 Self& self
364 ) {
365 switch (state) {
366 case dispatch: {
367 state = post;
368 auto& a_strm{strm};
369 as::dispatch(
370 as::bind_executor(
371 a_strm.raw_strand_,
372 force_move(self)
373 )
374 );
375 } break;
376 case post: {
377 BOOST_ASSERT(strm.in_strand());
378 state = write;
379 auto& a_strm{strm};
380 a_strm.queue_.post(
381 as::bind_executor(
382 a_strm.raw_strand_,
383 force_move(self)
384 )
385 );
386 } break;
387 case write: {
388 BOOST_ASSERT(strm.in_strand());
389 strm.queue_.start_work();
390 if (strm.lowest_layer().is_open()) {
391 state = complete;
392 auto& a_strm{strm};
393 auto cbs = packet->const_buffer_sequence();
394 async_write(
395 a_strm.nl_,
396 cbs,
397 as::bind_executor(
398 a_strm.raw_strand_,
399 force_move(self)
400 )
401 );
402 }
403 else {
404 state = complete;
405 auto& a_strm{strm};
406 as::dispatch(
407 as::bind_executor(
408 a_strm.raw_strand_,
409 as::append(
410 force_move(self),
411 errc::make_error_code(errc::connection_reset),
412 0
413 )
414 )
415 );
416 }
417 } break;
418 default:
419 BOOST_ASSERT(false);
420 break;
421 }
422 }
423
424 template <typename Self>
425 void operator()(
426 Self& self,
427 error_code const& ec,
428 std::size_t bytes_transferred
429 ) {
430 BOOST_ASSERT(strm.in_strand());
431 if (ec) {
432 strm.queue_.stop_work();
433 auto& a_strm{strm};
434 as::post(
435 as::bind_executor(
436 a_strm.raw_strand_,
437 [&a_strm,wp = a_strm.weak_from_this()] {
438 if (auto sp = wp.lock()) {
439 a_strm.queue_.poll_one();
440 }
441 }
442 )
443 );
444 self.complete(ec, bytes_transferred);
445 return;
446 }
447 switch (state) {
448 case complete: {
449 strm.queue_.stop_work();
450 auto& a_strm{strm};
451 as::post(
452 as::bind_executor(
453 a_strm.raw_strand_,
454 [&a_strm, wp = a_strm.weak_from_this()] {
455 if (auto sp = wp.lock()) {
456 a_strm.queue_.poll_one();
457 }
458 }
459 )
460 );
461 self.complete(ec, bytes_transferred);
462 } break;
463 default:
464 BOOST_ASSERT(false);
465 break;
466 }
467 }
468 };
469
470 struct close_impl {
471 this_type& strm;
472 enum {
473 dispatch,
474 close,
475 drop1,
476 complete
477 } state = dispatch;
478 this_type_sp life_keeper = strm.shared_from_this();
479
480 template <typename Self>
481 void operator()(
482 Self& self
483 ) {
484 BOOST_ASSERT(state == dispatch);
485 state = close;
486 auto& a_strm{strm};
487 as::dispatch(
488 as::bind_executor(
489 a_strm.raw_strand_,
490 as::append(
491 force_move(self),
492 error_code{},
493 std::ref(a_strm.nl_)
494 )
495 )
496 );
497 }
498
499#if defined(ASYNC_MQTT_USE_WS)
500 template <typename Self, typename Stream>
501 void operator()(
502 Self& self,
503 error_code const& ec,
504 std::size_t /*size*/,
505 std::reference_wrapper<Stream> stream
506 ) {
507 BOOST_ASSERT(strm.in_strand());
508 if constexpr(is_ws<Stream>::value) {
509 BOOST_ASSERT(state == complete);
510 if (ec) {
511 if (ec == bs::websocket::error::closed) {
512 ASYNC_MQTT_LOG("mqtt_impl", info)
513 << ASYNC_MQTT_ADD_VALUE(address, this)
514 << "ws async_read (for close) success";
515 }
516 else {
517 ASYNC_MQTT_LOG("mqtt_impl", info)
518 << ASYNC_MQTT_ADD_VALUE(address, this)
519 << "ws async_read (for close):" << ec.message();
520 }
521 state = close;
522 auto& a_strm{strm};
523 as::dispatch(
524 as::bind_executor(
525 a_strm.raw_strand_,
526 as::append(
527 force_move(self),
528 error_code{},
529 std::ref(stream.get().next_layer())
530 )
531 )
532 );
533 }
534 else {
535 auto& a_strm{strm};
536 auto buffer = std::make_shared<bs::flat_buffer>();
537 stream.get().async_read(
538 *buffer,
539 as::bind_executor(
540 a_strm.raw_strand_,
541 as::append(
542 as::consign(
543 force_move(self),
544 buffer
545 ),
546 force_move(stream)
547 )
548 )
549 );
550 }
551 }
552 }
553#endif // defined(ASYNC_MQTT_USE_WS)
554
555 template <typename Self, typename Stream>
556 void operator()(
557 Self& self,
558 error_code const& ec,
559 std::reference_wrapper<Stream> stream
560 ) {
561 BOOST_ASSERT(strm.in_strand());
562 switch (state) {
563 case close: {
564 if constexpr(is_ws<Stream>::value) {
565 if (stream.get().is_open()) {
566 state = drop1;
567 auto& a_strm{strm};
568 stream.get().async_close(
569 bs::websocket::close_code::none,
570 as::bind_executor(
571 a_strm.raw_strand_,
572 as::append(
573 force_move(self),
574 force_move(stream)
575 )
576 )
577 );
578 }
579 else {
580 state = close;
581 auto& a_strm{strm};
582 as::dispatch(
583 as::bind_executor(
584 a_strm.raw_strand_,
585 as::append(
586 force_move(self),
587 error_code{},
588 std::ref(stream.get().next_layer())
589 )
590 )
591 );
592 }
593 }
594 else if constexpr(is_tls<Stream>::value) {
595 auto& a_strm{strm};
596 ASYNC_MQTT_LOG("mqtt_impl", info)
597 << ASYNC_MQTT_ADD_VALUE(address, this)
598 << "TLS async_shutdown start with timeout";
599 auto tim = std::make_shared<as::steady_timer>(a_strm.raw_strand_, shutdown_timeout);
600 tim->async_wait(
601 as::bind_executor(
602 a_strm.raw_strand_,
603 [this, &next_layer = stream.get().next_layer()] (error_code const& ec) {
604 if (!ec) {
605 ASYNC_MQTT_LOG("mqtt_impl", info)
606 << ASYNC_MQTT_ADD_VALUE(address, this)
607 << "TLS async_shutdown timeout";
608 error_code ec;
609 next_layer.close(ec);
610 }
611 }
612 )
613 );
614 stream.get().async_shutdown(
615 as::bind_executor(
616 a_strm.raw_strand_,
617 as::append(
618 as::consign(
619 force_move(self),
620 tim
621 ),
622 std::ref(stream.get().next_layer())
623 )
624 )
625 );
626 }
627 else {
628 error_code ec;
629 if (stream.get().is_open()) {
630 ASYNC_MQTT_LOG("mqtt_impl", info)
631 << ASYNC_MQTT_ADD_VALUE(address, this)
632 << "TCP close";
633 stream.get().close(ec);
634 }
635 else {
636 ASYNC_MQTT_LOG("mqtt_impl", info)
637 << ASYNC_MQTT_ADD_VALUE(address, this)
638 << "TCP already closed";
639 }
640 self.complete(ec);
641 }
642 } break;
643 case drop1: {
644 if constexpr(is_ws<Stream>::value) {
645 if (ec) {
646 ASYNC_MQTT_LOG("mqtt_impl", info)
647 << ASYNC_MQTT_ADD_VALUE(address, this)
648 << "ws async_close:" << ec.message();
649 state = close;
650 auto& a_strm{strm};
651 as::dispatch(
652 as::bind_executor(
653 a_strm.raw_strand_,
654 as::append(
655 force_move(self),
656 error_code{},
657 std::ref(stream.get().next_layer())
658 )
659 )
660 );
661 return;
662 }
663 state = complete;
664 auto& a_strm{strm};
665 auto buffer = std::make_shared<bs::flat_buffer>();
666 stream.get().async_read(
667 *buffer,
668 as::bind_executor(
669 a_strm.raw_strand_,
670 as::append(
671 as::consign(
672 force_move(self),
673 buffer
674 ),
675 force_move(stream)
676 )
677 )
678 );
679 }
680 } break;
681 default:
682 BOOST_ASSERT(false);
683 break;
684 }
685 }
686 };
687
688private:
689 next_layer_type nl_;
690 raw_strand_type raw_strand_{nl_.get_executor()};
691 strand_type strand_{as::any_io_executor{raw_strand_}};
692 ioc_queue queue_;
693 static_vector<char, 5> header_remaining_length_buf_ = static_vector<char, 5>(5);
694};
695
696} // namespace async_mqtt
697
698#endif // ASYNC_MQTT_STREAM_HPP
Definition packet_variant.hpp:49