// async_mqtt 5.0.0 - stream.hpp
1// Copyright Takatoshi Kondo 2022
2//
3// Distributed under the Boost Software License, Version 1.0.
4// (See accompanying file LICENSE_1_0.txt or copy at
5// http://www.boost.org/LICENSE_1_0.txt)
6
7#if !defined(ASYNC_MQTT_STREAM_HPP)
8#define ASYNC_MQTT_STREAM_HPP
9
10#include <iostream>
11
12#include <utility>
13#include <type_traits>
14
15#include <boost/system/error_code.hpp>
16#include <boost/asio/strand.hpp>
17#include <boost/asio/io_context.hpp>
18#include <boost/asio/compose.hpp>
19#include <boost/asio/bind_executor.hpp>
20#include <boost/asio/append.hpp>
21#include <boost/asio/consign.hpp>
22
23#if defined(ASYNC_MQTT_USE_WS)
24#include <boost/beast/websocket/stream.hpp>
25#endif // defined(ASYNC_MQTT_USE_WS)
26
27#include <async_mqtt/stream_traits.hpp>
28#include <async_mqtt/util/make_shared_helper.hpp>
29#include <async_mqtt/util/optional.hpp>
30#include <async_mqtt/util/static_vector.hpp>
31#include <async_mqtt/util/ioc_queue.hpp>
32#include <async_mqtt/buffer.hpp>
33#include <async_mqtt/constant.hpp>
34#include <async_mqtt/is_strand.hpp>
35#include <async_mqtt/exception.hpp>
36#include <async_mqtt/tls.hpp>
37#include <async_mqtt/log.hpp>
38
39namespace async_mqtt {
40
41namespace as = boost::asio;
42namespace sys = boost::system;
43
/**
 * @brief Compile-time predicate: is @p Stream a WebSocket stream?
 *
 * The primary template evaluates to false for every type. When
 * ASYNC_MQTT_USE_WS is defined, a partial specialization for
 * bs::websocket::stream<NextLayer> evaluates to true.
 */
template <class Stream>
struct is_ws : std::integral_constant<bool, false> {};
46
#if defined(ASYNC_MQTT_USE_WS)
namespace bs = boost::beast;

// Any Boost.Beast WebSocket stream, whatever its next layer, is a WS stream.
template <class NextLayer>
struct is_ws<bs::websocket::stream<NextLayer>> : std::integral_constant<bool, true> {};

/**
 * @brief Free-function async_write for WebSocket streams.
 *
 * Lets generic code call async_write(stream, buffers, token) on a
 * bs::websocket::stream. It simply forwards to the stream's member
 * async_write with the given buffer sequence and completion token.
 *
 * Participates in overload resolution only when @p ConstBufferSequence
 * satisfies as::is_const_buffer_sequence.
 *
 * @param stream WebSocket stream to write to
 * @param cbs    buffers to send
 * @param token  completion token, forwarded unchanged
 * @return whatever stream.async_write returns for @p token
 */
template <
    class NextLayer,
    class ConstBufferSequence,
    class CompletionToken,
    std::enable_if_t<as::is_const_buffer_sequence<ConstBufferSequence>::value>* = nullptr
>
auto async_write(
    bs::websocket::stream<NextLayer>& stream,
    ConstBufferSequence const& cbs,
    CompletionToken&& token
) {
    auto&& forwarded_token = std::forward<CompletionToken>(token);
    return stream.async_write(cbs, std::forward<CompletionToken>(forwarded_token));
}

#endif // defined(ASYNC_MQTT_USE_WS)
71
/**
 * @brief Compile-time predicate: is @p Stream a TLS stream?
 *
 * False for every type by default. When ASYNC_MQTT_USE_TLS is defined,
 * tls::stream<NextLayer> (from async_mqtt/tls.hpp) evaluates to true.
 */
template <class Stream>
struct is_tls : std::integral_constant<bool, false> {};

#if defined(ASYNC_MQTT_USE_TLS)
template <class NextLayer>
struct is_tls<tls::stream<NextLayer>> : std::integral_constant<bool, true> {};
#endif // defined(ASYNC_MQTT_USE_TLS)
79
80template <typename NextLayer, template <typename> typename Strand = as::strand>
81class stream : public std::enable_shared_from_this<stream<NextLayer, Strand>> {
82public:
83 using this_type = stream<NextLayer, Strand>;
84 using this_type_sp = std::shared_ptr<this_type>;
85 using next_layer_type = typename std::remove_reference<NextLayer>::type;
87 using raw_strand_type = as::strand<executor_type>;
88 using strand_type = Strand<as::any_io_executor>;
89
90 template <typename T>
91 friend class make_shared_helper;
92
93 template <
94 typename T,
95 typename... Args,
96 std::enable_if_t<!std::is_same_v<std::decay_t<T>, this_type>>* = nullptr
97 >
98 static std::shared_ptr<this_type> create(T&& t, Args&&... args) {
99 return make_shared_helper<this_type>::make_shared(std::forward<T>(t), std::forward<Args>(args)...);
100 }
101
102 ~stream() {
103 ASYNC_MQTT_LOG("mqtt_impl", trace)
104 << ASYNC_MQTT_ADD_VALUE(address, this)
105 << "destroy";
106 }
107
108 stream(this_type&&) = delete;
109 stream(this_type const&) = delete;
110 this_type& operator=(this_type&&) = delete;
111 this_type& operator=(this_type const&) = delete;
112
113 auto const& next_layer() const {
114 return nl_;
115 }
116 auto& next_layer() {
117 return nl_;
118 }
119
120 auto const& lowest_layer() const {
121 return get_lowest_layer(nl_);
122 }
123 auto& lowest_layer() {
124 return get_lowest_layer(nl_);
125 }
126
127 auto get_executor() const {
128 return nl_.get_executor();
129 }
130 auto get_executor() {
131 return nl_.get_executor();
132 }
133
134 template <typename CompletionToken>
135 auto
136 read_packet(
137 CompletionToken&& token
138 ) {
139 return
140 as::async_compose<
141 CompletionToken,
142 void(error_code const&, buffer)
143 >(
144 read_packet_impl{
145 *this
146 },
147 token
148 );
149 }
150
151 template <typename Packet, typename CompletionToken>
152 auto
153 write_packet(
154 Packet packet,
155 CompletionToken&& token
156 ) {
157 return
158 as::async_compose<
159 CompletionToken,
160 void(error_code const&, std::size_t)
161 >(
162 write_packet_impl<Packet>{
163 *this,
164 std::make_shared<Packet>(force_move(packet))
165 },
166 token
167 );
168 }
169
170 strand_type const& strand() const {
171 return strand_;
172 }
173
174 strand_type& strand() {
175 return strand_;
176 }
177
178 raw_strand_type const& raw_strand() const {
179 return raw_strand_;
180 };
181
182 raw_strand_type& raw_strand() {
183 return raw_strand_;
184 };
185
186 bool in_strand() const {
187 return raw_strand().running_in_this_thread();
188 }
189
190 template<typename CompletionToken>
191 auto
192 close(CompletionToken&& token) {
193 return
194 as::async_compose<
195 CompletionToken,
196 void(error_code const&)
197 >(
198 close_impl{
199 *this
200 },
201 token
202 );
203 }
204
205 void set_bulk_write(bool val) {
206 bulk_write_ = val;
207 }
208
209private:
210
211 // constructor
212 template <
213 typename T,
214 typename... Args,
215 std::enable_if_t<!std::is_same_v<std::decay_t<T>, this_type>>* = nullptr
216 >
217 explicit
218 stream(T&& t, Args&&... args)
219 :nl_{std::forward<T>(t), std::forward<Args>(args)...}
220 {
221#if defined(ASYNC_MQTT_USE_WS)
222 if constexpr(is_ws<next_layer_type>::value) {
223 nl_.binary(true);
224 nl_.set_option(
225 bs::websocket::stream_base::decorator(
226 [](bs::websocket::request_type& req) {
227 req.set("Sec-WebSocket-Protocol", "mqtt");
228 }
229 )
230 );
231 }
232#endif // defined(ASYNC_MQTT_USE_WS)
233 }
234
235 struct read_packet_impl {
236 this_type& strm;
237 std::size_t received = 0;
238 std::uint32_t mul = 1;
239 std::uint32_t rl = 0;
240 shared_ptr_array spa = nullptr;
241 this_type_sp life_keeper = strm.shared_from_this();
242 enum { dispatch, header, remaining_length, complete } state = dispatch;
243
244 template <typename Self>
245 void operator()(
246 Self& self
247 ) {
248 switch (state) {
249 case dispatch: {
250 state = header;
251 auto& a_strm{strm};
252 as::dispatch(
253 as::bind_executor(
254 a_strm.raw_strand_,
255 force_move(self)
256 )
257 );
258 } break;
259 case header: {
260 BOOST_ASSERT(strm.in_strand());
261 // read fixed_header
262 auto address = &strm.header_remaining_length_buf_[received];
263 auto& a_strm{strm};
264 async_read(
265 a_strm.nl_,
266 as::buffer(address, 1),
267 as::bind_executor(
268 a_strm.raw_strand_,
269 force_move(self)
270 )
271 );
272 } break;
273 default:
274 BOOST_ASSERT(false);
275 break;
276 }
277 }
278
279 template <typename Self>
280 void operator()(
281 Self& self,
282 error_code const& ec,
283 std::size_t bytes_transferred
284 ) {
285 (void)bytes_transferred; // Ignore unused argument in release build
286
287 BOOST_ASSERT(strm.in_strand());
288 if (ec) {
289 self.complete(ec, buffer{});
290 return;
291 }
292
293 switch (state) {
294 case header:
295 BOOST_ASSERT(bytes_transferred == 1);
296 state = remaining_length;
297 ++received;
298 // read the first remaining_length
299 {
300 auto address = &strm.header_remaining_length_buf_[received];
301 auto& a_strm{strm};
302 async_read(
303 a_strm.nl_,
304 as::buffer(address, 1),
305 as::bind_executor(
306 a_strm.raw_strand_,
307 force_move(self)
308 )
309 );
310 }
311 break;
312 case remaining_length:
313 BOOST_ASSERT(bytes_transferred == 1);
314 ++received;
315 if (strm.header_remaining_length_buf_[received - 1] & 0b10000000) {
316 // remaining_length continues
317 if (received == 5) {
318 ASYNC_MQTT_LOG("mqtt_impl", warning)
319 << ASYNC_MQTT_ADD_VALUE(address, this)
320 << "out of size remaining length";
321 self.complete(
322 sys::errc::make_error_code(sys::errc::protocol_error),
323 buffer{}
324 );
325 return;
326 }
327 rl += (strm.header_remaining_length_buf_[received - 1] & 0b01111111) * mul;
328 mul *= 128;
329 auto address = &strm.header_remaining_length_buf_[received];
330 auto& a_strm{strm};
331 async_read(
332 a_strm.nl_,
333 as::buffer(address, 1),
334 as::bind_executor(
335 a_strm.raw_strand_,
336 force_move(self)
337 )
338 );
339 }
340 else {
341 // remaining_length end
342 rl += (strm.header_remaining_length_buf_[received - 1] & 0b01111111) * mul;
343
344 spa = make_shared_ptr_array(received + rl);
345 std::copy(
346 strm.header_remaining_length_buf_.data(),
347 strm.header_remaining_length_buf_.data() + received, spa.get()
348 );
349
350 if (rl == 0) {
351 auto ptr = spa.get();
352 self.complete(ec, buffer{ptr, ptr + received + rl, force_move(spa)});
353 return;
354 }
355 else {
356 state = complete;
357 auto address = &spa[std::ptrdiff_t(received)];
358 auto& a_strm{strm};
359 async_read(
360 a_strm.nl_,
361 as::buffer(address, rl),
362 as::bind_executor(
363 a_strm.raw_strand_,
364 force_move(self)
365 )
366 );
367 }
368 }
369 break;
370 case complete: {
371 auto ptr = spa.get();
372 self.complete(ec, buffer{ptr, ptr + received + rl, force_move(spa)});
373 } break;
374 default:
375 BOOST_ASSERT(false);
376 break;
377 }
378 }
379 };
380
381 template <typename Packet>
382 struct write_packet_impl {
383 this_type& strm;
384 std::shared_ptr<Packet> packet;
385 std::size_t size = packet->size();
386 this_type_sp life_keeper = strm.shared_from_this();
387 enum { dispatch, post, write, bulk_write, complete } state = dispatch;
388
389 template <typename Self>
390 void operator()(
391 Self& self
392 ) {
393 switch (state) {
394 case dispatch: {
395 state = post;
396 auto& a_strm{strm};
397 as::dispatch(
398 as::bind_executor(
399 a_strm.raw_strand_,
400 force_move(self)
401 )
402 );
403 } break;
404 case post: {
405 BOOST_ASSERT(strm.in_strand());
406 auto& a_strm{strm};
407 auto& a_packet{*packet};
408 if (!a_strm.bulk_write_ || a_strm.queue_.immediate_executable()) {
409 state = write;
410 }
411 else {
412 state = bulk_write;
413 auto cbs = a_packet.const_buffer_sequence();
414 std::copy(cbs.begin(), cbs.end(), std::back_inserter(a_strm.storing_cbs_));
415 }
416 a_strm.queue_.post(
417 as::bind_executor(
418 a_strm.raw_strand_,
419 force_move(self)
420 )
421 );
422 } break;
423 case write: {
424 BOOST_ASSERT(strm.in_strand());
425 strm.queue_.start_work();
426 if (strm.lowest_layer().is_open()) {
427 state = complete;
428 auto& a_strm{strm};
429 auto& a_packet{*packet};
430 async_write(
431 a_strm.nl_,
432 a_packet.const_buffer_sequence(),
433 as::bind_executor(
434 a_strm.raw_strand_,
435 force_move(self)
436 )
437 );
438 }
439 else {
440 state = complete;
441 auto& a_strm{strm};
442 as::dispatch(
443 as::bind_executor(
444 a_strm.raw_strand_,
445 as::append(
446 force_move(self),
447 errc::make_error_code(errc::connection_reset),
448 0
449 )
450 )
451 );
452 }
453 } break;
454 case bulk_write: {
455 BOOST_ASSERT(strm.in_strand());
456 strm.queue_.start_work();
457 if (strm.lowest_layer().is_open()) {
458 state = complete;
459 auto& a_strm{strm};
460 if (a_strm.storing_cbs_.empty()) {
461 auto& a_strm{strm};
462 auto& a_size{size};
463 as::dispatch(
464 as::bind_executor(
465 a_strm.raw_strand_,
466 as::append(
467 force_move(self),
468 errc::make_error_code(errc::success),
469 a_size
470 )
471 )
472 );
473 }
474 else {
475 a_strm.sending_cbs_ = force_move(a_strm.storing_cbs_);
476 async_write(
477 a_strm.nl_,
478 a_strm.sending_cbs_,
479 as::bind_executor(
480 a_strm.raw_strand_,
481 force_move(self)
482 )
483 );
484 }
485 }
486 else {
487 state = complete;
488 auto& a_strm{strm};
489 as::dispatch(
490 as::bind_executor(
491 a_strm.raw_strand_,
492 as::append(
493 force_move(self),
494 errc::make_error_code(errc::connection_reset),
495 0
496 )
497 )
498 );
499 }
500 } break;
501 default:
502 BOOST_ASSERT(false);
503 break;
504 }
505 }
506
507 template <typename Self>
508 void operator()(
509 Self& self,
510 error_code const& ec,
511 std::size_t bytes_transferred
512 ) {
513 BOOST_ASSERT(strm.in_strand());
514 if (ec) {
515 strm.queue_.stop_work();
516 auto& a_strm{strm};
517 as::post(
518 as::bind_executor(
519 a_strm.raw_strand_,
520 [&a_strm,wp = a_strm.weak_from_this()] {
521 if (auto sp = wp.lock()) {
522 a_strm.queue_.poll_one();
523 }
524 }
525 )
526 );
527 self.complete(ec, bytes_transferred);
528 return;
529 }
530 switch (state) {
531 case complete: {
532 strm.queue_.stop_work();
533 strm.sending_cbs_.clear();
534 auto& a_strm{strm};
535 as::post(
536 as::bind_executor(
537 a_strm.raw_strand_,
538 [&a_strm, wp = a_strm.weak_from_this()] {
539 if (auto sp = wp.lock()) {
540 a_strm.queue_.poll_one();
541 }
542 }
543 )
544 );
545 self.complete(ec, size);
546 } break;
547 default:
548 BOOST_ASSERT(false);
549 break;
550 }
551 }
552 };
553
554 struct close_impl {
555 this_type& strm;
556 enum {
557 dispatch,
558 close,
559 drop1,
560 complete
561 } state = dispatch;
562 this_type_sp life_keeper = strm.shared_from_this();
563
564 template <typename Self>
565 void operator()(
566 Self& self
567 ) {
568 BOOST_ASSERT(state == dispatch);
569 state = close;
570 auto& a_strm{strm};
571 as::dispatch(
572 as::bind_executor(
573 a_strm.raw_strand_,
574 as::append(
575 force_move(self),
576 error_code{},
577 std::ref(a_strm.nl_)
578 )
579 )
580 );
581 }
582
583#if defined(ASYNC_MQTT_USE_WS)
584 template <typename Self, typename Stream>
585 void operator()(
586 Self& self,
587 error_code const& ec,
588 std::size_t /*size*/,
589 std::reference_wrapper<Stream> stream
590 ) {
591 BOOST_ASSERT(strm.in_strand());
592 if constexpr(is_ws<Stream>::value) {
593 BOOST_ASSERT(state == complete);
594 if (ec) {
595 if (ec == bs::websocket::error::closed) {
596 ASYNC_MQTT_LOG("mqtt_impl", info)
597 << ASYNC_MQTT_ADD_VALUE(address, this)
598 << "ws async_read (for close) success";
599 }
600 else {
601 ASYNC_MQTT_LOG("mqtt_impl", info)
602 << ASYNC_MQTT_ADD_VALUE(address, this)
603 << "ws async_read (for close):" << ec.message();
604 }
605 state = close;
606 auto& a_strm{strm};
607 as::dispatch(
608 as::bind_executor(
609 a_strm.raw_strand_,
610 as::append(
611 force_move(self),
612 error_code{},
613 std::ref(stream.get().next_layer())
614 )
615 )
616 );
617 }
618 else {
619 auto& a_strm{strm};
620 auto buffer = std::make_shared<bs::flat_buffer>();
621 stream.get().async_read(
622 *buffer,
623 as::bind_executor(
624 a_strm.raw_strand_,
625 as::append(
626 as::consign(
627 force_move(self),
628 buffer
629 ),
630 force_move(stream)
631 )
632 )
633 );
634 }
635 }
636 }
637#endif // defined(ASYNC_MQTT_USE_WS)
638
639 template <typename Self, typename Stream>
640 void operator()(
641 Self& self,
642 error_code const& ec,
643 std::reference_wrapper<Stream> stream
644 ) {
645 BOOST_ASSERT(strm.in_strand());
646 switch (state) {
647 case close: {
648#if defined(ASYNC_MQTT_USE_WS)
649 if constexpr(is_ws<Stream>::value) {
650 if (stream.get().is_open()) {
651 state = drop1;
652 auto& a_strm{strm};
653 stream.get().async_close(
654 bs::websocket::close_code::none,
655 as::bind_executor(
656 a_strm.raw_strand_,
657 as::append(
658 force_move(self),
659 force_move(stream)
660 )
661 )
662 );
663 }
664 else {
665 state = close;
666 auto& a_strm{strm};
667 as::dispatch(
668 as::bind_executor(
669 a_strm.raw_strand_,
670 as::append(
671 force_move(self),
672 error_code{},
673 std::ref(stream.get().next_layer())
674 )
675 )
676 );
677 }
678 }
679 else
680#endif // defined(ASYNC_MQTT_USE_WS)
681 if constexpr(is_tls<Stream>::value) {
682 auto& a_strm{strm};
683 ASYNC_MQTT_LOG("mqtt_impl", info)
684 << ASYNC_MQTT_ADD_VALUE(address, this)
685 << "TLS async_shutdown start with timeout";
686 auto tim = std::make_shared<as::steady_timer>(a_strm.raw_strand_, shutdown_timeout);
687 tim->async_wait(
688 as::bind_executor(
689 a_strm.raw_strand_,
690 [this, &next_layer = stream.get().next_layer()] (error_code const& ec) {
691 if (!ec) {
692 ASYNC_MQTT_LOG("mqtt_impl", info)
693 << ASYNC_MQTT_ADD_VALUE(address, this)
694 << "TLS async_shutdown timeout";
695 error_code ec;
696 next_layer.close(ec);
697 }
698 }
699 )
700 );
701 stream.get().async_shutdown(
702 as::bind_executor(
703 a_strm.raw_strand_,
704 as::append(
705 as::consign(
706 force_move(self),
707 tim
708 ),
709 std::ref(stream.get().next_layer())
710 )
711 )
712 );
713 }
714 else {
715 error_code ec;
716 if (stream.get().is_open()) {
717 ASYNC_MQTT_LOG("mqtt_impl", info)
718 << ASYNC_MQTT_ADD_VALUE(address, this)
719 << "TCP close";
720 stream.get().close(ec);
721 }
722 else {
723 ASYNC_MQTT_LOG("mqtt_impl", info)
724 << ASYNC_MQTT_ADD_VALUE(address, this)
725 << "TCP already closed";
726 }
727 strm.storing_cbs_.clear();
728 strm.sending_cbs_.clear();
729 self.complete(ec);
730 }
731 } break;
732 case drop1: {
733#if defined(ASYNC_MQTT_USE_WS)
734 if constexpr(is_ws<Stream>::value) {
735 if (ec) {
736 ASYNC_MQTT_LOG("mqtt_impl", info)
737 << ASYNC_MQTT_ADD_VALUE(address, this)
738 << "ws async_close:" << ec.message();
739 state = close;
740 auto& a_strm{strm};
741 as::dispatch(
742 as::bind_executor(
743 a_strm.raw_strand_,
744 as::append(
745 force_move(self),
746 error_code{},
747 std::ref(stream.get().next_layer())
748 )
749 )
750 );
751 return;
752 }
753 state = complete;
754 auto& a_strm{strm};
755 auto buffer = std::make_shared<bs::flat_buffer>();
756 stream.get().async_read(
757 *buffer,
758 as::bind_executor(
759 a_strm.raw_strand_,
760 as::append(
761 as::consign(
762 force_move(self),
763 buffer
764 ),
765 force_move(stream)
766 )
767 )
768 );
769 }
770#else // defined(ASYNC_MQTT_USE_WS)
771 (void)ec;
772#endif // defined(ASYNC_MQTT_USE_WS)
773 } break;
774 default:
775 BOOST_ASSERT(false);
776 break;
777 }
778 }
779 };
780
781private:
782 next_layer_type nl_;
783 raw_strand_type raw_strand_{nl_.get_executor()};
784 strand_type strand_{as::any_io_executor{raw_strand_}};
785 ioc_queue queue_;
786 static_vector<char, 5> header_remaining_length_buf_ = static_vector<char, 5>(5);
787 std::vector<as::const_buffer> storing_cbs_;
788 std::vector<as::const_buffer> sending_cbs_;
789 bool bulk_write_ = false;
790};
791
792} // namespace async_mqtt
793
794#endif // ASYNC_MQTT_STREAM_HPP
// (documentation-page residue) Definition packet_variant.hpp:49