1 /* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /*
3  *     Copyright 2015 Couchbase, Inc.
4  *
5  *   Licensed under the Apache License, Version 2.0 (the "License");
6  *   you may not use this file except in compliance with the License.
7  *   You may obtain a copy of the License at
8  *
9  *       http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *   Unless required by applicable law or agreed to in writing, software
12  *   distributed under the License is distributed on an "AS IS" BASIS,
13  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *   See the License for the specific language governing permissions and
15  *   limitations under the License.
16  */
17 #include "client_connection.h"
18 #include "client_mcbp_commands.h"
19 #include "frameinfo.h"
20 
21 #include <cbsasl/client.h>
22 #include <mcbp/mcbp.h>
23 #include <mcbp/protocol/framebuilder.h>
24 #include <memcached/protocol_binary.h>
25 #include <nlohmann/json.hpp>
26 #include <platform/compress.h>
27 #include <platform/dirutils.h>
28 #include <platform/socket.h>
29 #include <platform/strerror.h>
30 
#include <algorithm>
#include <cctype>
#include <cerrno>
#include <functional>
#include <gsl/gsl>
#include <iostream>
#include <limits>
#include <memory>
#ifndef WIN32
#include <netdb.h>
#include <netinet/tcp.h> // For TCP_NODELAY etc
#endif
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
46 
// When the COUCHBASE_PACKET_DUMP environment variable is set (to any value),
// outgoing data passed to sendBuffer() and frames read by recvFrame() are
// dumped to stderr.
static const bool packet_dump = nullptr != getenv("COUCHBASE_PACKET_DUMP");
48 
operator <<(::std::ostream& os, const DocumentInfo& info)49 ::std::ostream& operator<<(::std::ostream& os, const DocumentInfo& info) {
50     return os << "id:" << info.id << " flags:" << info.flags
51               << " exp:" << info.expiration
52               << " datatype:" << int(info.datatype) << " cas:" << info.cas;
53 }
54 
operator <<(::std::ostream& os, const Document& doc)55 ::std::ostream& operator<<(::std::ostream& os, const Document& doc) {
56     os << "info:" << doc.info << " value: [" << std::hex;
57     for (auto& v : doc.value) {
58         os << int(v) << " ";
59     }
60     return os << std::dec << "]";
61 }
62 
compress()63 void Document::compress() {
64     if (mcbp::datatype::is_snappy(protocol_binary_datatype_t(info.datatype))) {
65         throw std::invalid_argument(
66                 "Document::compress: Cannot compress already compressed "
67                 "document.");
68     }
69 
70     cb::compression::Buffer buf;
71     cb::compression::deflate(cb::compression::Algorithm::Snappy, value, buf);
72     value = {buf.data(), buf.size()};
73     info.datatype = cb::mcbp::Datatype(uint8_t(info.datatype) |
74                                        uint8_t(cb::mcbp::Datatype::Snappy));
75 }
76 
77 /////////////////////////////////////////////////////////////////////////
78 // Implementation of the MemcachedConnection class
79 /////////////////////////////////////////////////////////////////////////
MemcachedConnection(std::string host, in_port_t port, sa_family_t family, bool ssl)80 MemcachedConnection::MemcachedConnection(std::string host,
81                                          in_port_t port,
82                                          sa_family_t family,
83                                          bool ssl)
84     : host(std::move(host)), port(port), family(family), ssl(ssl) {
85     if (ssl) {
86         char* env = getenv("COUCHBASE_SSL_CLIENT_CERT_PATH");
87         if (env != nullptr) {
88             setSslCertFile(std::string{env} + "/client.pem");
89             setSslKeyFile(std::string{env} + "/client.key");
90         }
91     }
92 }
93 
~MemcachedConnection()94 MemcachedConnection::~MemcachedConnection() {
95     close();
96 }
97 
close()98 void MemcachedConnection::close() {
99     effective_features.clear();
100     if (ssl) {
101         if (bio != nullptr) {
102             BIO_free_all(bio);
103             bio = nullptr;
104         }
105         if (context != nullptr) {
106             SSL_CTX_free(context);
107             context = nullptr;
108         }
109     }
110 
111     if (sock != INVALID_SOCKET) {
112         cb::net::shutdown(sock, SHUT_RDWR);
113         cb::net::closesocket(sock);
114         sock = INVALID_SOCKET;
115     }
116 }
117 
try_connect_socket(struct addrinfo* next, const std::string& hostname, in_port_t port)118 SOCKET try_connect_socket(struct addrinfo* next,
119                           const std::string& hostname,
120                           in_port_t port) {
121     SOCKET sfd = cb::net::socket(
122             next->ai_family, next->ai_socktype, next->ai_protocol);
123     if (sfd == INVALID_SOCKET) {
124         throw std::system_error(cb::net::get_socket_error(),
125                                 std::system_category(),
126                                 "socket() failed (" + hostname + " " +
127                                         std::to_string(port) + ")");
128     }
129 
130 #ifdef WIN32
131     // BIO_new_socket pass the socket as an int, but it is a SOCKET on
132     // Windows.. On windows a socket is an unsigned value, and may
133     // get an overflow inside openssl (I don't know the exact width of
134     // the SOCKET, and how openssl use the value internally). This
135     // class is mostly used from the test framework so let's throw
136     // an exception instead and treat it like a test failure (to be
137     // on the safe side). We'll be refactoring to SCHANNEL in the
138     // future anyway.
139     if (sfd > std::numeric_limits<int>::max()) {
140         cb::net::closesocket(sfd);
141         throw std::runtime_error(
142                 "Socket value too big "
143                 "(may trigger behavior openssl)");
144     }
145 #endif
146 
147     // When running unit tests on our Windows CV system we somtimes
148     // see connect fail with WSAEADDRINUSE. For a client socket
149     // we don't bind the socket as that's implicit from calling
150     // connect. Mark the socket reusable so that the kernel may
151     // reuse the socket earlier
152     const int flag = 1;
153     cb::net::setsockopt(sfd,
154                         SOL_SOCKET,
155                         SO_REUSEADDR,
156                         reinterpret_cast<const void*>(&flag),
157                         sizeof(flag));
158 
159     // Try to set the nodelay mode on the socket (but ignore
160     // if we fail to do so..
161     cb::net::setsockopt(sfd,
162                         IPPROTO_TCP,
163                         TCP_NODELAY,
164                         reinterpret_cast<const void*>(&flag),
165                         sizeof(flag));
166 
167     if (cb::net::connect(sfd, next->ai_addr, next->ai_addrlen) == SOCKET_ERROR) {
168         auto error = cb::net::get_socket_error();
169         cb::net::closesocket(sfd);
170 #ifdef WIN32
171         WSASetLastError(error);
172 #endif
173         throw std::system_error(error,
174                                 std::system_category(),
175                                 "connect() failed (" + hostname + " " +
176                                         std::to_string(port) + ")");
177     }
178 
179     // Socket is connected and ready to use
180     return sfd;
181 }
182 
new_socket(const std::string& host, in_port_t port, sa_family_t family)183 SOCKET cb::net::new_socket(const std::string& host,
184                            in_port_t port,
185                            sa_family_t family) {
186     struct addrinfo hints = {};
187     hints.ai_flags = AI_PASSIVE;
188     hints.ai_protocol = IPPROTO_TCP;
189     hints.ai_socktype = SOCK_STREAM;
190     hints.ai_family = family;
191 
192     int error;
193     struct addrinfo* ai;
194     std::string hostname{host};
195 
196     if (hostname.empty() || hostname == "localhost") {
197         if (family == AF_INET) {
198             hostname.assign("127.0.0.1");
199         } else if (family == AF_INET6){
200             hostname.assign("::1");
201         } else if (family == AF_UNSPEC) {
202             hostname.assign("localhost");
203         }
204     }
205 
206     error = getaddrinfo(
207             hostname.c_str(), std::to_string(port).c_str(), &hints, &ai);
208 
209     if (error != 0) {
210         throw std::system_error(error,
211                                 std::system_category(),
212                                 "Failed to resolve address host: \"" +
213                                         hostname +
214                                         "\" Port: " + std::to_string(port));
215     }
216 
217     bool unit_tests = getenv("MEMCACHED_UNIT_TESTS") != nullptr;
218 
219     // Iterate over all of the entries returned by getaddrinfo
220     // and try to connect to them. Depending on the input data we
221     // might get multiple returns (ex: localhost with AF_UNSPEC returns
222     // both IPv4 and IPv6 address, and IPv4 could fail while IPv6
223     // might succeed.
224     for (auto* next = ai; next; next = next->ai_next) {
225         int retry = unit_tests ? 200 : 0;
226         do {
227             try {
228                 auto sfd = try_connect_socket(next, hostname, port);
229                 freeaddrinfo(ai);
230                 return sfd;
231             } catch (const std::system_error& error) {
232                 if (unit_tests) {
233                     std::cerr << "Failed building socket: " << error.what()
234                               << std::endl;
235 #ifndef WIN32
236                     const int WSAEADDRINUSE = EADDRINUSE;
237 #endif
238                     if (error.code().value() == WSAEADDRINUSE) {
239                         std::cerr << "EADDRINUSE.. backing off" << std::endl;
240                         std::this_thread::sleep_for(
241                                 std::chrono::milliseconds(10));
242                     } else {
243                         // Not subject for backoff and retry
244                         retry = 0;
245                     }
246                 }
247             }
248         } while (retry-- > 0);
249         // Try next entry returned from getaddinfo
250     }
251 
252     freeaddrinfo(ai);
253     return INVALID_SOCKET;
254 }
255 
new_ssl_socket( const std::string& host, in_port_t port, sa_family_t family, std::function<void(SSL_CTX*)> setup_ssl_ctx)256 std::tuple<SOCKET, SSL_CTX*, BIO*> cb::net::new_ssl_socket(
257         const std::string& host,
258         in_port_t port,
259         sa_family_t family,
260         std::function<void(SSL_CTX*)> setup_ssl_ctx) {
261     auto sock = cb::net::new_socket(host, port, family);
262     if (sock == INVALID_SOCKET) {
263         return std::tuple<SOCKET, SSL_CTX*, BIO*>{
264                 INVALID_SOCKET, nullptr, nullptr};
265     }
266 
267     /* we're connected */
268     auto* context = SSL_CTX_new(SSLv23_client_method());
269     if (context == nullptr) {
270         throw std::runtime_error("Failed to create openssl client context");
271     }
272 
273     if (setup_ssl_ctx) {
274         setup_ssl_ctx(context);
275     }
276 
277     // Ensure read/write operations only return after the
278     // handshake and successful completion.
279     SSL_CTX_set_mode(context, SSL_MODE_AUTO_RETRY);
280 
281     BIO* bio = BIO_new_ssl(context, 1);
282     BIO_push(bio, BIO_new_socket(gsl::narrow<int>(sock), 0));
283 
284     if (BIO_do_handshake(bio) <= 0) {
285         BIO_free_all(bio);
286         SSL_CTX_free(context);
287         throw std::runtime_error("Failed to do SSL handshake!");
288     }
289 
290     return std::tuple<SOCKET, SSL_CTX*, BIO*>{sock, context, bio};
291 }
292 
releaseSocket()293 SOCKET MemcachedConnection::releaseSocket() {
294     if (ssl) {
295         throw std::runtime_error("releaseSocket: Can't release SSL socket");
296     }
297     auto ret = sock;
298     sock = INVALID_SOCKET;
299     return ret;
300 }
301 
tls_protocol_to_options(const std::string& protocol)302 long tls_protocol_to_options(const std::string& protocol) {
303     /* MB-12359 - Disable SSLv2 & SSLv3 due to POODLE */
304     long disallow = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
305 
306     std::string minimum(protocol);
307     std::transform(minimum.begin(), minimum.end(), minimum.begin(), tolower);
308 
309     if (minimum.empty() || minimum == "tlsv1") {
310         disallow |= SSL_OP_NO_TLSv1_3 | SSL_OP_NO_TLSv1_2 | SSL_OP_NO_TLSv1_1;
311     } else if (minimum == "tlsv1.1" || minimum == "tlsv1_1") {
312         disallow |= SSL_OP_NO_TLSv1_3 | SSL_OP_NO_TLSv1_2 | SSL_OP_NO_TLSv1;
313     } else if (minimum == "tlsv1.2" || minimum == "tlsv1_2") {
314         disallow |= SSL_OP_NO_TLSv1_3 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1;
315     } else if (minimum == "tlsv1.3" || minimum == "tlsv1_3") {
316         disallow |= SSL_OP_NO_TLSv1_2 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1;
317     } else {
318         throw std::invalid_argument("Unknown protocol: " + minimum);
319     }
320 
321     return disallow;
322 }
323 
connect()324 void MemcachedConnection::connect() {
325     if (bio != nullptr) {
326         BIO_free_all(bio);
327         bio = nullptr;
328     }
329 
330     if (context != nullptr) {
331         SSL_CTX_free(context);
332     }
333 
334     if (sock != INVALID_SOCKET) {
335         cb::net::shutdown(sock, SHUT_RDWR);
336         cb::net::closesocket(sock);
337         sock = INVALID_SOCKET;
338     }
339 
340     if (ssl) {
341         std::tie(sock, context, bio) = cb::net::new_ssl_socket(
342                 host, port, family, [this](SSL_CTX* context) {
343                     if (!tls_protocol.empty()) {
344                         SSL_CTX_set_options(
345                                 context, tls_protocol_to_options(tls_protocol));
346                     }
347 
348                     if (SSL_CTX_set_ciphersuites(context,
349                                                  tls13_ciphers.c_str()) == 0 &&
350                         !tls13_ciphers.empty()) {
351                         throw std::runtime_error(
352                                 "Failed to select a cipher suite from: " +
353                                 tls13_ciphers);
354                     }
355 
356                     if (SSL_CTX_set_cipher_list(context,
357                                                 tls12_ciphers.c_str()) == 0 &&
358                         !tls12_ciphers.empty()) {
359                         throw std::runtime_error(
360                                 "Failed to select a cipher suite from: " +
361                                 tls12_ciphers);
362                     }
363 
364                     if (!ssl_cert_file.empty() && !ssl_key_file.empty()) {
365                         if (!SSL_CTX_use_certificate_file(context,
366                                                           ssl_cert_file.c_str(),
367                                                           SSL_FILETYPE_PEM) ||
368                             !SSL_CTX_use_PrivateKey_file(context,
369                                                          ssl_key_file.c_str(),
370                                                          SSL_FILETYPE_PEM) ||
371                             !SSL_CTX_check_private_key(context)) {
372                             std::vector<char> ssl_err(1024);
373                             ERR_error_string_n(ERR_get_error(),
374                                                ssl_err.data(),
375                                                ssl_err.size());
376                             SSL_CTX_free(context);
377                             throw std::runtime_error(
378                                     std::string("Failed to use SSL cert and "
379                                                 "key: ") +
380                                     ssl_err.data());
381                         }
382                     }
383                 });
384     } else {
385         sock = cb::net::new_socket(host, port, family);
386     }
387 
388     if (sock == INVALID_SOCKET) {
389         auto error = cb::net::get_socket_error();
390         std::string msg("Failed to connect to: ");
391         if (family == AF_INET || family == AF_UNSPEC) {
392             if (host.empty()) {
393                 msg += "localhost:";
394             } else {
395                 msg += host + ":";
396             }
397         } else {
398             if (host.empty()) {
399                 msg += "[::1]:";
400             } else {
401                 msg += "[" + host + "]:";
402             }
403         }
404         msg.append(std::to_string(port));
405         throw std::system_error(error, std::system_category(), msg);
406     }
407 
408     bool unitTests = getenv("MEMCACHED_UNIT_TESTS") != nullptr;
409     if (!ssl && unitTests) {
410         // Enable LINGER with zero timeout. This changes the
411         // behaviour of close() - any unsent data will be
412         // discarded, and the connection will be immediately
413         // closed with a RST, and is immediately destroyed.  This
414         // has the advantage that the socket doesn't enter
415         // TIME_WAIT; and hence doesn't consume an emphemeral port
416         // until it times out (default 60s).
417         //
418         // By using LINGER we (hopefully!) avoid issues in CV jobs
419         // where ephemeral ports are exhausted and hence tests
420         // intermittently fail. One minor downside the RST
421         // triggers a warning in the server side logs: 'read
422         // error: Connection reset by peer'.
423         //
424         // Note that this isn't enabled for SSL sockets, which don't
425         // appear to be happy with having the underlying socket closed
426         // immediately; I suspect due to the additional out-of-band
427         // messages SSL may send/recv in addition to normal traffic.
428         struct linger sl {};
429         sl.l_onoff = 1;
430         sl.l_linger = 0;
431         cb::net::setsockopt(sock,
432                             SOL_SOCKET,
433                             SO_LINGER,
434                             reinterpret_cast<const void*>(&sl),
435                             sizeof(sl));
436     }
437 }
438 
sendBufferSsl(cb::const_byte_buffer buf)439 void MemcachedConnection::sendBufferSsl(cb::const_byte_buffer buf) {
440     const auto* data = reinterpret_cast<const char*>(buf.data());
441     cb::const_byte_buffer::size_type nbytes = buf.size();
442     cb::const_byte_buffer::size_type offset = 0;
443 
444     while (offset < nbytes) {
445         int nw = BIO_write(
446                 bio, data + offset, gsl::narrow<int>(nbytes - offset));
447         if (nw <= 0) {
448             if (BIO_should_retry(bio) == 0) {
449                 throw std::runtime_error(
450                         "Failed to write data, BIO_write returned " +
451                         std::to_string(nw));
452             }
453         } else {
454             offset += nw;
455         }
456     }
457 }
458 
sendBufferSsl(const std::vector<iovec>& list)459 void MemcachedConnection::sendBufferSsl(const std::vector<iovec>& list) {
460     for (auto buf : list) {
461         sendBufferSsl({reinterpret_cast<uint8_t*>(buf.iov_base), buf.iov_len});
462     }
463 }
464 
sendBufferPlain(cb::const_byte_buffer buf)465 void MemcachedConnection::sendBufferPlain(cb::const_byte_buffer buf) {
466     const auto* data = reinterpret_cast<const char*>(buf.data());
467     cb::const_byte_buffer::size_type nbytes = buf.size();
468     cb::const_byte_buffer::size_type offset = 0;
469 
470     while (offset < nbytes) {
471         auto nw = cb::net::send(sock, data + offset, nbytes - offset, 0);
472         if (nw <= 0) {
473             throw std::system_error(
474                     cb::net::get_socket_error(),
475                     std::system_category(),
476                     "MemcachedConnection::sendFramePlain: failed to send data");
477         } else {
478             offset += nw;
479         }
480     }
481 }
482 
sendBufferPlain(const std::vector<iovec>& iov)483 void MemcachedConnection::sendBufferPlain(const std::vector<iovec>& iov) {
484     // Calculate total size.
485     int bytes_remaining = 0;
486     for (const auto& io : iov) {
487         bytes_remaining += int(io.iov_len);
488     }
489 
490     // Encode sendmsg() message header.
491     msghdr msg{};
492     // sendmsg() doesn't actually change the value of msg_iov; but as
493     // it's a C API it doesn't have a const modifier. Therefore need
494     // to cast away const.
495     msg.msg_iov = const_cast<iovec*>(iov.data());
496     msg.msg_iovlen = int(iov.size());
497 
498     // repeatedly call sendmsg() until the complete payload has been
499     // transmitted.
500     for (;;) {
501         auto bytes_sent = cb::net::sendmsg(sock, &msg, 0);
502         if (bytes_sent < 0) {
503             throw std::system_error(cb::net::get_socket_error(),
504                                     std::system_category(),
505                                     "MemcachedConnection::sendBufferPlain: "
506                                     "sendmsg() failed to send data");
507         }
508 
509         bytes_remaining -= bytes_sent;
510         if (bytes_remaining == 0) {
511             // All data sent.
512             return;
513         }
514 
515         // Partial send. Remove the completed iovec entries from the
516         // list of pending writes.
517         while ((msg.msg_iovlen > 0) &&
518                (bytes_sent >= ssize_t(msg.msg_iov->iov_len))) {
519             // Complete element consumed; update msg_iov / iovlen to next
520             // element.
521             bytes_sent -= (ssize_t)msg.msg_iov->iov_len;
522             msg.msg_iovlen--;
523             msg.msg_iov++;
524         }
525 
526         // Might have written just part of the last iovec entry;
527         // adjust it so the next write will do the rest.
528         if (bytes_sent > 0) {
529             msg.msg_iov->iov_base =
530                     (void*)((unsigned char*)msg.msg_iov->iov_base + bytes_sent);
531             msg.msg_iov->iov_len -= bytes_sent;
532         }
533     }
534 }
535 
readSsl(Frame& frame, size_t bytes)536 void MemcachedConnection::readSsl(Frame& frame, size_t bytes) {
537     Frame::size_type offset = frame.payload.size();
538     frame.payload.resize(bytes + offset);
539     char* data = reinterpret_cast<char*>(frame.payload.data()) + offset;
540 
541     size_t total = 0;
542 
543     while (total < bytes) {
544         int nr = BIO_read(bio, data + total, gsl::narrow<int>(bytes - total));
545         if (nr <= 0) {
546             if (BIO_should_retry(bio) == 0) {
547                 throw std::runtime_error(
548                         "Failed to read data, BIO_read returned " +
549                         std::to_string(nr));
550             }
551         } else {
552             total += nr;
553         }
554     }
555 }
556 
readPlain(Frame& frame, size_t bytes)557 void MemcachedConnection::readPlain(Frame& frame, size_t bytes) {
558     Frame::size_type offset = frame.payload.size();
559     frame.payload.resize(bytes + offset);
560     char* data = reinterpret_cast<char*>(frame.payload.data()) + offset;
561 
562     size_t total = 0;
563 
564     while (total < bytes) {
565         auto nr = cb::net::recv(sock, data + total, bytes - total, 0);
566         if (nr <= 0) {
567             auto error = cb::net::get_socket_error();
568             if (nr == 0) {
569                 // nr == 0 means that the other end closed the connection.
570                 // Given that we expected to read more data, let's throw
571                 // an connection reset exception
572                 error = ECONNRESET;
573             }
574 
575             throw std::system_error(error, std::system_category(),
576                                     "MemcachedConnection::readPlain: failed to read data");
577         } else {
578             total += nr;
579         }
580     }
581 }
582 
sendFrame(const Frame& frame)583 void MemcachedConnection::sendFrame(const Frame& frame) {
584     sendBuffer({frame.payload.data(), frame.payload.size()});
585 }
586 
sendBuffer(const std::vector<iovec>& list)587 void MemcachedConnection::sendBuffer(const std::vector<iovec>& list) {
588     if (packet_dump) {
589         std::vector<uint8_t> blob;
590         for (auto& entry : list) {
591             const auto* ptr = static_cast<const uint8_t*>(entry.iov_base);
592             std::copy(ptr, ptr + entry.iov_len, std::back_inserter(blob));
593         }
594         try {
595             cb::mcbp::dumpStream({blob.data(), blob.size()}, std::cerr);
596         } catch (const std::exception&) {
597             // ignore..
598         }
599     }
600 
601     if (ssl) {
602         sendBufferSsl(list);
603     } else {
604         sendBufferPlain(list);
605     }
606 }
607 
sendBuffer(cb::const_byte_buffer buf)608 void MemcachedConnection::sendBuffer(cb::const_byte_buffer buf) {
609     if (packet_dump) {
610         try {
611             cb::mcbp::dumpStream(buf, std::cerr);
612         } catch (const std::exception&) {
613             // ignore..
614         }
615     }
616     if (ssl) {
617         sendBufferSsl(buf);
618     } else {
619         sendBufferPlain(buf);
620     }
621 }
622 
sendPartialFrame(Frame& frame, Frame::size_type length)623 void MemcachedConnection::sendPartialFrame(Frame& frame,
624                                            Frame::size_type length) {
625     // Move the remainder to a new frame.
626     auto rem_first = frame.payload.begin() + length;
627     auto rem_last = frame.payload.end();
628     std::vector<uint8_t> remainder;
629     std::copy(rem_first, rem_last, std::back_inserter(remainder));
630     frame.payload.erase(rem_first, rem_last);
631 
632     // Send the partial frame.
633     sendFrame(frame);
634 
635     // Swap the old payload with the remainder.
636     frame.payload.swap(remainder);
637 }
638 
read(Frame& frame, size_t bytes)639 void MemcachedConnection::read(Frame& frame, size_t bytes) {
640     if (ssl) {
641         readSsl(frame, bytes);
642     } else {
643         readPlain(frame, bytes);
644     }
645 }
646 
stats(const std::string& subcommand)647 nlohmann::json MemcachedConnection::stats(const std::string& subcommand) {
648     nlohmann::json ret;
649     stats(
650             [&ret](const std::string& key, const std::string& value) -> void {
651                 if (value.empty()) {
652                     ret[key] = "";
653                     return;
654                 }
655                 try {
656                     auto v = nlohmann::json::parse(value);
657                     ret[key] = v;
658                 } catch (const nlohmann::json::exception&) {
659                     ret[key] = value;
660                 }
661             },
662             subcommand);
663     return ret;
664 }
665 
setSslCertFile(const std::string& file)666 void MemcachedConnection::setSslCertFile(const std::string& file)  {
667     if (file.empty()) {
668         ssl_cert_file.clear();
669         return;
670     }
671     auto path = file;
672     cb::io::sanitizePath(path);
673     if (!cb::io::isFile(path)) {
674         throw std::system_error(std::make_error_code(std::errc::no_such_file_or_directory),
675                                 "Can't use [" + path + "]");
676     }
677     ssl_cert_file = path;
678 }
679 
setSslKeyFile(const std::string& file)680 void MemcachedConnection::setSslKeyFile(const std::string& file) {
681     if (file.empty()) {
682         ssl_key_file.clear();
683         return;
684     }
685     auto path = file;
686     cb::io::sanitizePath(path);
687     if (!cb::io::isFile(path)) {
688         throw std::system_error(std::make_error_code(std::errc::no_such_file_or_directory),
689                                 "Can't use [" + path + "]");
690     }
691     ssl_key_file = path;
692 }
693 
setTlsProtocol(std::string protocol)694 void MemcachedConnection::setTlsProtocol(std::string protocol) {
695     tls_protocol = std::move(protocol);
696 }
697 
setTls12Ciphers(std::string ciphers)698 void MemcachedConnection::setTls12Ciphers(std::string ciphers) {
699     tls12_ciphers = std::move(ciphers);
700 }
701 
setTls13Ciphers(std::string ciphers)702 void MemcachedConnection::setTls13Ciphers(std::string ciphers) {
703     tls13_ciphers = std::move(ciphers);
704 }
705 
to_frame(const BinprotCommand& command)706 static Frame to_frame(const BinprotCommand& command) {
707     Frame frame;
708     command.encode(frame.payload);
709     return frame;
710 }
711 
clone()712 std::unique_ptr<MemcachedConnection> MemcachedConnection::clone() {
713     auto result = std::make_unique<MemcachedConnection>(
714             this->host, this->port, this->family, this->ssl);
715     result->auto_retry_tmpfail = this->auto_retry_tmpfail;
716     result->setSslCertFile(this->ssl_cert_file);
717     result->setSslKeyFile(this->ssl_key_file);
718     result->connect();
719     result->applyFeatures("", this->effective_features);
720     return result;
721 }
722 
recvFrame(Frame& frame)723 void MemcachedConnection::recvFrame(Frame& frame) {
724     frame.reset();
725     // A memcached packet starts with a fixed header
726     MemcachedConnection::read(frame, sizeof(cb::mcbp::Header));
727 
728     auto magic = cb::mcbp::Magic(frame.payload.at(0));
729     if (magic != cb::mcbp::Magic::ClientRequest &&
730         magic != cb::mcbp::Magic::ClientResponse &&
731         magic != cb::mcbp::Magic::ServerRequest &&
732         magic != cb::mcbp::Magic::ServerResponse &&
733         magic != cb::mcbp::Magic::AltClientResponse) {
734         throw std::runtime_error("Invalid magic received: " +
735                                  std::to_string(frame.payload.at(0)));
736     }
737 
738     const auto* header =
739             reinterpret_cast<const cb::mcbp::Header*>(frame.payload.data());
740     MemcachedConnection::read(frame, header->getBodylen());
741     if (packet_dump) {
742         cb::mcbp::dump(frame.payload.data(), std::cerr);
743     }
744 }
745 
sendCommand(const BinprotCommand& command)746 void MemcachedConnection::sendCommand(const BinprotCommand& command) {
747     traceData.reset();
748 
749     auto encoded = command.encode();
750 
751     // encoded contains the message header (as owning vector<uint8_t>),
752     // plus a variable number of (non-owning) byte buffers. Create
753     // a single vector of byte buffers for all; then send in a single
754     // sendmsg() call (to avoid copying any data), with a single syscall.
755 
756     // Perf: this function previously used multiple calls to
757     // sendBuffer() (one per header / buffer) to send the data without
758     // copying / re-forming it. While this does reduce copying cost; it requires
759     // one send() syscall per chunk. Benchmarks show that is actually
760     // *more* expensive overall (particulary when measuring server
761     // performance) as the server can read the first header chunk;
762     // then attempts to read the body which hasn't been delievered yet
763     // and hence has to go around the libevent loop again to read the
764     // body.
765 
766     std::vector<iovec> message;
767     iovec iov{};
768     iov.iov_base = encoded.header.data();
769     iov.iov_len = encoded.header.size();
770     message.push_back(iov);
771     for (auto buf : encoded.bufs) {
772         iov.iov_base = const_cast<uint8_t*>(buf.data());
773         iov.iov_len = buf.size();
774         message.push_back(iov);
775     }
776 
777     sendBuffer(message);
778 }
779 
recvResponse(BinprotResponse& response)780 void MemcachedConnection::recvResponse(BinprotResponse& response) {
781     Frame frame;
782     traceData.reset();
783     recvFrame(frame);
784     response.assign(std::move(frame.payload));
785     traceData = response.getTracingData();
786 }
787 
authenticate(const std::string& username, const std::string& password, const std::string& mech)788 void MemcachedConnection::authenticate(const std::string& username,
789                                        const std::string& password,
790                                        const std::string& mech) {
791     cb::sasl::client::ClientContext client(
792             [username]() -> std::string { return username; },
793             [password]() -> std::string { return password; },
794             mech);
795     auto client_data = client.start();
796 
797     if (client_data.first != cb::sasl::Error::OK) {
798         throw std::runtime_error(std::string("cbsasl_client_start (") +
799                                  std::string(client.getName()) +
800                                  std::string("): ") +
801                                  ::to_string(client_data.first));
802     }
803 
804     BinprotSaslAuthCommand authCommand;
805     authCommand.setChallenge(client_data.second);
806     authCommand.setMechanism(client.getName());
807     auto response = execute(authCommand);
808 
809     while (response.getStatus() == cb::mcbp::Status::AuthContinue) {
810         auto respdata = response.getData();
811         client_data =
812                 client.step({reinterpret_cast<const char*>(respdata.data()),
813                              respdata.size()});
814         if (client_data.first != cb::sasl::Error::OK &&
815             client_data.first != cb::sasl::Error::CONTINUE) {
816             reconnect();
817             throw std::runtime_error(std::string("cbsasl_client_step: ") +
818                                      ::to_string(client_data.first));
819         }
820 
821         BinprotSaslStepCommand stepCommand;
822         stepCommand.setMechanism(client.getName());
823         stepCommand.setChallenge(client_data.second);
824         response = execute(stepCommand);
825     }
826 
827     if (!response.isSuccess()) {
828         throw ConnectionError("Authentication failed", response);
829     }
830 }
831 
createBucket(const std::string& name, const std::string& config, BucketType type)832 void MemcachedConnection::createBucket(const std::string& name,
833                                        const std::string& config,
834                                        BucketType type) {
835     std::string module;
836     switch (type) {
837     case BucketType::Memcached:
838         module.assign("default_engine.so");
839         break;
840     case BucketType::EWouldBlock:
841         module.assign("ewouldblock_engine.so");
842         break;
843     case BucketType::Couchbase:
844         module.assign("ep.so");
845         break;
846     default:
847         throw std::runtime_error("Not implemented");
848     }
849 
850     BinprotCreateBucketCommand command(name.c_str());
851     command.setConfig(module, config);
852 
853     const auto response = execute(command);
854     if (!response.isSuccess()) {
855         throw ConnectionError("Create bucket failed", response);
856     }
857 }
858 
deleteBucket(const std::string& name)859 void MemcachedConnection::deleteBucket(const std::string& name) {
860     BinprotGenericCommand command(cb::mcbp::ClientOpcode::DeleteBucket, name);
861     const auto response = execute(command);
862     if (!response.isSuccess()) {
863         throw ConnectionError("Delete bucket failed", response);
864     }
865 }
866 
selectBucket(const std::string& name)867 void MemcachedConnection::selectBucket(const std::string& name) {
868     BinprotGenericCommand command(cb::mcbp::ClientOpcode::SelectBucket, name);
869     const auto response = execute(command);
870     if (!response.isSuccess()) {
871         throw ConnectionError(
872                 std::string{"Select bucket [" + name + "] failed"}, response);
873     }
874 }
875 
to_string()876 std::string MemcachedConnection::to_string() {
877     std::string ret("Memcached connection ");
878     ret.append(std::to_string(port));
879     if (family == AF_INET6) {
880         ret.append("[::1]:");
881     } else {
882         ret.append("127.0.0.1:");
883     }
884 
885     ret.append(std::to_string(port));
886 
887     if (ssl) {
888         ret.append(" ssl");
889     }
890 
891     return ret;
892 }
893 
listBuckets( GetFrameInfoFunction getFrameInfo)894 std::vector<std::string> MemcachedConnection::listBuckets(
895         GetFrameInfoFunction getFrameInfo) {
896     BinprotGenericCommand command(cb::mcbp::ClientOpcode::ListBuckets);
897     applyFrameInfos(command, getFrameInfo);
898     const auto response = execute(command);
899     if (!response.isSuccess()) {
900         throw ConnectionError("List bucket failed", response);
901     }
902 
903     std::vector<std::string> ret;
904 
905     // the value contains a list of bucket names separated by space.
906     std::istringstream iss(response.getDataString());
907     std::copy(std::istream_iterator<std::string>(iss),
908               std::istream_iterator<std::string>(),
909               std::back_inserter(ret));
910 
911     return ret;
912 }
913 
get( const std::string& id, Vbid vbucket, std::function<std::vector<std::unique_ptr<FrameInfo>>()> getFrameInfo)914 Document MemcachedConnection::get(
915         const std::string& id,
916         Vbid vbucket,
917         std::function<std::vector<std::unique_ptr<FrameInfo>>()> getFrameInfo) {
918     BinprotGetCommand command;
919     command.setKey(id);
920     command.setVBucket(vbucket);
921     applyFrameInfos(command, getFrameInfo);
922 
923     const auto response = BinprotGetResponse(execute(command));
924     if (!response.isSuccess()) {
925         throw ConnectionError("Failed to get: " + id, response.getStatus());
926     }
927 
928     Document ret;
929     ret.info.flags = response.getDocumentFlags();
930     ret.info.cas = response.getCas();
931     ret.info.id = id;
932     ret.info.datatype = response.getResponse().getDatatype();
933     ret.value = response.getDataString();
934     return ret;
935 }
936 
mget( const std::vector<std::pair<const std::string, Vbid>>& id, std::function<void(std::unique_ptr<Document>&)> documentCallback, std::function<void(const std::string&, const cb::mcbp::Response&)> errorCallback, GetFrameInfoFunction getFrameInfo)937 void MemcachedConnection::mget(
938         const std::vector<std::pair<const std::string, Vbid>>& id,
939         std::function<void(std::unique_ptr<Document>&)> documentCallback,
940         std::function<void(const std::string&, const cb::mcbp::Response&)>
941                 errorCallback,
942         GetFrameInfoFunction getFrameInfo) {
943     using cb::mcbp::ClientOpcode;
944 
945     // One of the motivations for this method is to be able to test a
946     // pipeline of commands (to get them reordered on the server if OoO
947     // is enabled). Sending each command as an individual packet may
948     // cause the server to completely execute the command before it goes
949     // back into the read state and sees the next command.
950     std::vector<uint8_t> pipeline;
951 
952     int ii = 0;
953     for (const auto& doc : id) {
954         BinprotGetCommand command;
955         command.setOp(ClientOpcode::Getq); // Use the quiet one
956         command.setKey(doc.first);
957         command.setVBucket(doc.second);
958         command.setOpaque(ii++);
959         applyFrameInfos(command, getFrameInfo);
960 
961         std::vector<uint8_t> cmd;
962         command.encode(cmd);
963         std::copy(cmd.begin(), cmd.end(), std::back_inserter(pipeline));
964     }
965 
966     // Add a noop command to terminate the sequence
967     {
968         BinprotGenericCommand command{ClientOpcode::Noop};
969         std::vector<uint8_t> cmd;
970         command.encode(cmd);
971         std::copy(cmd.begin(), cmd.end(), std::back_inserter(pipeline));
972     }
973 
974     // Now send the pipeline to the other end!
975     sendBuffer(cb::const_byte_buffer{pipeline.data(), pipeline.size()});
976 
977     // read until I see the noop response
978     auto done = false;
979     do {
980         BinprotResponse rsp;
981         recvResponse(rsp);
982         auto opcode = rsp.getOp();
983         if (opcode == ClientOpcode::Noop) {
984             done = true;
985         } else if (opcode != ClientOpcode::Getq) {
986             throw std::runtime_error(
987                     "MemcachedConnection::mget: Received unexpected opcode");
988         } else {
989             BinprotGetResponse getResponse(std::move(rsp));
990             auto opaque = getResponse.getResponse().getOpaque();
991             if (opaque >= id.size()) {
992                 throw std::runtime_error(
993                         "MemcachedConnection::mget: Invalid opaque received");
994             }
995             const auto& key = id[opaque].first;
996 
997             if (getResponse.isSuccess()) {
998                 auto doc = std::make_unique<Document>();
999                 doc->info.flags = getResponse.getDocumentFlags();
1000                 doc->info.cas = getResponse.getCas();
1001                 doc->info.id = key;
1002                 doc->info.datatype = getResponse.getResponse().getDatatype();
1003                 doc->value = getResponse.getDataString();
1004                 documentCallback(doc);
1005             } else if (errorCallback) {
1006                 errorCallback(key, getResponse.getResponse());
1007             }
1008         }
1009     } while (!done);
1010 }
1011 
encodeCmdGet(const std::string& id, Vbid vbucket)1012 Frame MemcachedConnection::encodeCmdGet(const std::string& id, Vbid vbucket) {
1013     BinprotGetCommand command;
1014     command.setKey(id);
1015     command.setVBucket(vbucket);
1016     return to_frame(command);
1017 }
1018 
mutate(const DocumentInfo& info, Vbid vbucket, cb::const_byte_buffer value, MutationType type, GetFrameInfoFunction getFrameInfo)1019 MutationInfo MemcachedConnection::mutate(const DocumentInfo& info,
1020                                          Vbid vbucket,
1021                                          cb::const_byte_buffer value,
1022                                          MutationType type,
1023                                          GetFrameInfoFunction getFrameInfo) {
1024     BinprotMutationCommand command;
1025     command.setDocumentInfo(info);
1026     command.addValueBuffer(value);
1027     command.setVBucket(vbucket);
1028     command.setMutationType(type);
1029     applyFrameInfos(command, getFrameInfo);
1030 
1031     const auto response = BinprotMutationResponse(execute(command));
1032     if (!response.isSuccess()) {
1033         throw ConnectionError("Failed to store " + info.id,
1034                               response.getStatus());
1035     }
1036 
1037     return response.getMutationInfo();
1038 }
1039 
store(const std::string& id, Vbid vbucket, std::string value, cb::mcbp::Datatype datatype, GetFrameInfoFunction getFrameInfo)1040 MutationInfo MemcachedConnection::store(const std::string& id,
1041                                         Vbid vbucket,
1042                                         std::string value,
1043                                         cb::mcbp::Datatype datatype,
1044                                         GetFrameInfoFunction getFrameInfo) {
1045     Document doc{};
1046     doc.value = std::move(value);
1047     doc.info.id = id;
1048     doc.info.datatype = datatype;
1049     return mutate(doc, vbucket, MutationType::Set, getFrameInfo);
1050 }
1051 
stats( std::function<void(const std::string&, const std::string&)> callback, const std::string& group)1052 void MemcachedConnection::stats(
1053         std::function<void(const std::string&, const std::string&)> callback,
1054         const std::string& group) {
1055     BinprotGenericCommand cmd(cb::mcbp::ClientOpcode::Stat, group);
1056     sendCommand(cmd);
1057 
1058     int counter = 0;
1059 
1060     while (true) {
1061         BinprotResponse response;
1062         recvResponse(response);
1063 
1064         if (!response.isSuccess()) {
1065             throw ConnectionError("Stats failed", response);
1066         }
1067 
1068         if (!response.getBodylen()) {
1069             break;
1070         }
1071 
1072         std::string key = response.getKeyString();
1073 
1074         if (key.empty()) {
1075             key = std::to_string(counter++);
1076         }
1077         callback(key, response.getDataString());
1078     }
1079 }
1080 
statsMap( const std::string& subcommand)1081 std::map<std::string, std::string> MemcachedConnection::statsMap(
1082         const std::string& subcommand) {
1083     std::map<std::string, std::string> ret;
1084     stats([&ret](const std::string& key,
1085                  const std::string& value) -> void { ret[key] = value; },
1086           subcommand);
1087     return ret;
1088 }
1089 
configureEwouldBlockEngine(const EWBEngineMode& mode, ENGINE_ERROR_CODE err_code, uint32_t value, const std::string& key)1090 void MemcachedConnection::configureEwouldBlockEngine(const EWBEngineMode& mode,
1091                                                      ENGINE_ERROR_CODE err_code,
1092                                                      uint32_t value,
1093                                                      const std::string& key) {
1094     cb::mcbp::request::EWB_Payload payload;
1095     payload.setMode(uint32_t(mode));
1096     payload.setValue(uint32_t(value));
1097     payload.setInjectError(uint32_t(err_code));
1098 
1099     std::vector<uint8_t> buffer(sizeof(cb::mcbp::Request) +
1100                                 sizeof(cb::mcbp::request::EWB_Payload) +
1101                                 key.size());
1102     cb::mcbp::RequestBuilder builder({buffer.data(), buffer.size()});
1103     builder.setMagic(cb::mcbp::Magic::ClientRequest);
1104     builder.setOpcode(cb::mcbp::ClientOpcode::EwouldblockCtl);
1105     builder.setExtras(
1106             {reinterpret_cast<const uint8_t*>(&payload), sizeof(payload)});
1107     builder.setKey({reinterpret_cast<const uint8_t*>(key.data()), key.size()});
1108 
1109     Frame frame;
1110     frame.payload = std::move(buffer);
1111 
1112     auto response = execute(frame);
1113     auto* bytes = response.payload.data();
1114     auto* rsp = reinterpret_cast<protocol_binary_response_no_extras*>(bytes);
1115     auto& header = rsp->message.header.response;
1116     if (header.getStatus() != cb::mcbp::Status::Success) {
1117         throw ConnectionError("Failed to configure ewouldblock engine",
1118                               header.getStatus());
1119     }
1120 }
1121 
reloadAuditConfiguration( GetFrameInfoFunction getFrameInfo)1122 void MemcachedConnection::reloadAuditConfiguration(
1123         GetFrameInfoFunction getFrameInfo) {
1124     BinprotGenericCommand command(cb::mcbp::ClientOpcode::AuditConfigReload);
1125     applyFrameInfos(command, getFrameInfo);
1126     const auto response = execute(command);
1127     if (!response.isSuccess()) {
1128         throw ConnectionError("Failed to reload audit configuration", response);
1129     }
1130 }
1131 
hello(const std::string& userAgent, const std::string& userAgentVersion, const std::string& comment)1132 void MemcachedConnection::hello(const std::string& userAgent,
1133                                 const std::string& userAgentVersion,
1134                                 const std::string& comment) {
1135     applyFeatures(userAgent + " " + userAgentVersion, effective_features);
1136 }
1137 
applyFeatures(const std::string& agent, const Featureset& featureset)1138 void MemcachedConnection::applyFeatures(const std::string& agent,
1139                                         const Featureset& featureset) {
1140     BinprotHelloCommand command(agent);
1141     for (const auto& feature : featureset) {
1142         command.enableFeature(cb::mcbp::Feature(feature), true);
1143     }
1144 
1145     const auto response = BinprotHelloResponse(execute(command));
1146     if (!response.isSuccess()) {
1147         throw ConnectionError("Failed to say hello", response);
1148     }
1149 
1150     effective_features.clear();
1151     for (const auto& feature : response.getFeatures()) {
1152         effective_features.insert(uint16_t(feature));
1153     }
1154 }
1155 
setFeatures( const std::string& agent, const std::vector<cb::mcbp::Feature>& features)1156 void MemcachedConnection::setFeatures(
1157         const std::string& agent,
1158         const std::vector<cb::mcbp::Feature>& features) {
1159     BinprotHelloCommand command(agent);
1160     for (const auto& feature : features) {
1161         command.enableFeature(cb::mcbp::Feature(feature), true);
1162     }
1163 
1164     const auto response = BinprotHelloResponse(execute(command));
1165     if (!response.isSuccess()) {
1166         throw ConnectionError("Failed to say hello", response);
1167     }
1168 
1169     effective_features.clear();
1170     for (const auto& feature : response.getFeatures()) {
1171         effective_features.insert(uint16_t(feature));
1172     }
1173 
1174     // Verify that I was able to set all of them
1175     std::stringstream ss;
1176     ss << "[";
1177 
1178     for (const auto& feature : features) {
1179         if (!hasFeature(feature)) {
1180             ss << ::to_string(feature) << ",";
1181         }
1182     }
1183 
1184     auto missing = ss.str();
1185     if (missing.size() > 1) {
1186         missing.back() = ']';
1187         throw std::runtime_error("Failed to enable: " + missing);
1188     }
1189 }
1190 
setFeature(cb::mcbp::Feature feature, bool enabled)1191 void MemcachedConnection::setFeature(cb::mcbp::Feature feature, bool enabled) {
1192     Featureset currFeatures = effective_features;
1193     if (enabled) {
1194         currFeatures.insert(uint16_t(feature));
1195     } else {
1196         currFeatures.erase(uint16_t(feature));
1197     }
1198 
1199     applyFeatures("mcbp", currFeatures);
1200 
1201     if (enabled && !hasFeature(feature)) {
1202         throw std::runtime_error("Failed to enable " + ::to_string(feature));
1203     } else if (!enabled && hasFeature(feature)) {
1204         throw std::runtime_error("Failed to disable " + ::to_string(feature));
1205     }
1206 }
1207 
getSaslMechanisms()1208 std::string MemcachedConnection::getSaslMechanisms() {
1209     BinprotGenericCommand command(cb::mcbp::ClientOpcode::SaslListMechs);
1210     const auto response = execute(command);
1211     if (!response.isSuccess()) {
1212         throw ConnectionError("Failed to fetch sasl mechanisms", response);
1213     }
1214 
1215     return response.getDataString();
1216 }
1217 
ioctl_get(const std::string& key, GetFrameInfoFunction getFrameInfo)1218 std::string MemcachedConnection::ioctl_get(const std::string& key,
1219                                            GetFrameInfoFunction getFrameInfo) {
1220     BinprotGenericCommand command(cb::mcbp::ClientOpcode::IoctlGet, key);
1221     applyFrameInfos(command, getFrameInfo);
1222 
1223     const auto response = execute(command);
1224     if (!response.isSuccess()) {
1225         throw ConnectionError("ioctl_get '" + key + "' failed", response);
1226     }
1227     return response.getDataString();
1228 }
1229 
ioctl_set(const std::string& key, const std::string& value, GetFrameInfoFunction getFrameInfo)1230 void MemcachedConnection::ioctl_set(const std::string& key,
1231                                     const std::string& value,
1232                                     GetFrameInfoFunction getFrameInfo) {
1233     BinprotGenericCommand command(cb::mcbp::ClientOpcode::IoctlSet, key, value);
1234     applyFrameInfos(command, getFrameInfo);
1235     const auto response = execute(command);
1236     if (!response.isSuccess()) {
1237         throw ConnectionError("ioctl_set '" + key + "' failed", response);
1238     }
1239 }
1240 
increment(const std::string& key, uint64_t delta, uint64_t initial, rel_time_t exptime, MutationInfo* info, GetFrameInfoFunction getFrameInfo)1241 uint64_t MemcachedConnection::increment(const std::string& key,
1242                                         uint64_t delta,
1243                                         uint64_t initial,
1244                                         rel_time_t exptime,
1245                                         MutationInfo* info,
1246                                         GetFrameInfoFunction getFrameInfo) {
1247     return incr_decr(cb::mcbp::ClientOpcode::Increment,
1248                      key,
1249                      delta,
1250                      initial,
1251                      exptime,
1252                      info,
1253                      getFrameInfo);
1254 }
1255 
decrement(const std::string& key, uint64_t delta, uint64_t initial, rel_time_t exptime, MutationInfo* info, GetFrameInfoFunction getFrameInfo)1256 uint64_t MemcachedConnection::decrement(const std::string& key,
1257                                         uint64_t delta,
1258                                         uint64_t initial,
1259                                         rel_time_t exptime,
1260                                         MutationInfo* info,
1261                                         GetFrameInfoFunction getFrameInfo) {
1262     return incr_decr(cb::mcbp::ClientOpcode::Decrement,
1263                      key,
1264                      delta,
1265                      initial,
1266                      exptime,
1267                      info,
1268                      getFrameInfo);
1269 }
1270 
incr_decr(cb::mcbp::ClientOpcode opcode, const std::string& key, uint64_t delta, uint64_t initial, rel_time_t exptime, MutationInfo* info, GetFrameInfoFunction getFrameInfo)1271 uint64_t MemcachedConnection::incr_decr(cb::mcbp::ClientOpcode opcode,
1272                                         const std::string& key,
1273                                         uint64_t delta,
1274                                         uint64_t initial,
1275                                         rel_time_t exptime,
1276                                         MutationInfo* info,
1277                                         GetFrameInfoFunction getFrameInfo) {
1278     const char* opcode_name =
1279             (opcode == cb::mcbp::ClientOpcode::Increment) ? "incr" : "decr";
1280 
1281     BinprotIncrDecrCommand command;
1282     command.setOp(opcode).setKey(key);
1283     command.setDelta(delta).setInitialValue(initial).setExpiry(exptime);
1284     applyFrameInfos(command, getFrameInfo);
1285 
1286     const auto response = BinprotIncrDecrResponse(execute(command));
1287     if (!response.isSuccess()) {
1288         throw ConnectionError(
1289                 std::string(opcode_name) + " \"" + key + "\" failed.",
1290                 response.getStatus());
1291     }
1292 
1293     if (response.getDatatype() != PROTOCOL_BINARY_RAW_BYTES) {
1294         throw ValidationError(
1295                 std::string(opcode_name) + " \"" + key +
1296                 "\"invalid - response has incorrect datatype (" +
1297                 mcbp::datatype::to_string(response.getDatatype()) + ")");
1298     }
1299 
1300     if (info != nullptr) {
1301         *info = response.getMutationInfo();
1302     }
1303     return response.getValue();
1304 }
1305 
remove(const std::string& key, Vbid vbucket, uint64_t cas, GetFrameInfoFunction getFrameInfo)1306 MutationInfo MemcachedConnection::remove(const std::string& key,
1307                                          Vbid vbucket,
1308                                          uint64_t cas,
1309                                          GetFrameInfoFunction getFrameInfo) {
1310     BinprotRemoveCommand command;
1311     command.setKey(key).setVBucket(vbucket);
1312     command.setVBucket(vbucket);
1313     command.setCas(cas);
1314     applyFrameInfos(command, getFrameInfo);
1315 
1316     const auto response = BinprotRemoveResponse(execute(command));
1317 
1318     if (!response.isSuccess()) {
1319         throw ConnectionError("Failed to remove: " + key, response.getStatus());
1320     }
1321 
1322     return response.getMutationInfo();
1323 }
1324 
get_and_lock(const std::string& id, Vbid vbucket, uint32_t lock_timeout, GetFrameInfoFunction getFrameInfo)1325 Document MemcachedConnection::get_and_lock(const std::string& id,
1326                                            Vbid vbucket,
1327                                            uint32_t lock_timeout,
1328                                            GetFrameInfoFunction getFrameInfo) {
1329     BinprotGetAndLockCommand command;
1330     command.setKey(id);
1331     command.setVBucket(vbucket);
1332     command.setLockTimeout(lock_timeout);
1333     applyFrameInfos(command, getFrameInfo);
1334 
1335     const auto response = BinprotGetAndLockResponse(execute(command));
1336 
1337     if (!response.isSuccess()) {
1338         throw ConnectionError("Failed to get: " + id, response.getStatus());
1339     }
1340 
1341     Document ret;
1342     ret.info.flags = response.getDocumentFlags();
1343     ret.info.cas = response.getCas();
1344     ret.info.id = id;
1345     ret.info.datatype = response.getResponse().getDatatype();
1346     ret.value = response.getDataString();
1347     return ret;
1348 }
1349 
getFailoverLog( Vbid vbucket, GetFrameInfoFunction getFrameInfo)1350 BinprotResponse MemcachedConnection::getFailoverLog(
1351         Vbid vbucket, GetFrameInfoFunction getFrameInfo) {
1352     BinprotGetFailoverLogCommand command;
1353     command.setVBucket(vbucket);
1354     applyFrameInfos(command, getFrameInfo);
1355 
1356     return execute(command);
1357 }
1358 
unlock(const std::string& id, Vbid vbucket, uint64_t cas, GetFrameInfoFunction getFrameInfo)1359 void MemcachedConnection::unlock(const std::string& id,
1360                                  Vbid vbucket,
1361                                  uint64_t cas,
1362                                  GetFrameInfoFunction getFrameInfo) {
1363     BinprotUnlockCommand command;
1364     command.setKey(id);
1365     command.setVBucket(vbucket);
1366     command.setCas(cas);
1367     applyFrameInfos(command, getFrameInfo);
1368 
1369     const auto response = execute(command);
1370     if (!response.isSuccess()) {
1371         throw ConnectionError("unlock(): " + id, response.getStatus());
1372     }
1373 }
1374 
dropPrivilege(cb::rbac::Privilege privilege, GetFrameInfoFunction getFrameInfo)1375 void MemcachedConnection::dropPrivilege(cb::rbac::Privilege privilege,
1376                                         GetFrameInfoFunction getFrameInfo) {
1377     BinprotGenericCommand command(cb::mcbp::ClientOpcode::DropPrivilege,
1378                                   cb::rbac::to_string(privilege));
1379     applyFrameInfos(command, getFrameInfo);
1380 
1381     const auto response = execute(command);
1382     if (!response.isSuccess()) {
1383         throw ConnectionError("dropPrivilege \"" +
1384                                       cb::rbac::to_string(privilege) +
1385                                       "\" failed.",
1386                               response.getStatus());
1387     }
1388 }
1389 
mutateWithMeta( Document& doc, Vbid vbucket, uint64_t cas, uint64_t seqno, uint32_t metaOption, std::vector<uint8_t> metaExtras, GetFrameInfoFunction getFrameInfo)1390 MutationInfo MemcachedConnection::mutateWithMeta(
1391         Document& doc,
1392         Vbid vbucket,
1393         uint64_t cas,
1394         uint64_t seqno,
1395         uint32_t metaOption,
1396         std::vector<uint8_t> metaExtras,
1397         GetFrameInfoFunction getFrameInfo) {
1398     BinprotSetWithMetaCommand swm(
1399             doc, vbucket, cas, seqno, metaOption, metaExtras);
1400     applyFrameInfos(swm, getFrameInfo);
1401 
1402     const auto response = BinprotMutationResponse(execute(swm));
1403     if (!response.isSuccess()) {
1404         throw ConnectionError("Failed to mutateWithMeta " + doc.info.id + " " +
1405                                       response.getDataString(),
1406                               response.getStatus());
1407     }
1408 
1409     return response.getMutationInfo();
1410 }
1411 
observeSeqno( Vbid vbid, uint64_t uuid, GetFrameInfoFunction getFrameInfo)1412 ObserveInfo MemcachedConnection::observeSeqno(
1413         Vbid vbid, uint64_t uuid, GetFrameInfoFunction getFrameInfo) {
1414     BinprotObserveSeqnoCommand observe(vbid, uuid);
1415     applyFrameInfos(observe, getFrameInfo);
1416 
1417     const auto response = BinprotObserveSeqnoResponse(execute(observe));
1418     if (!response.isSuccess()) {
1419         throw ConnectionError(std::string("Failed to observeSeqno for ") +
1420                                       vbid.to_string() + " uuid:" +
1421                                       std::to_string(uuid),
1422                               response.getStatus());
1423     }
1424     return response.info;
1425 }
1426 
enablePersistence(GetFrameInfoFunction getFrameInfo)1427 void MemcachedConnection::enablePersistence(GetFrameInfoFunction getFrameInfo) {
1428     BinprotGenericCommand command(cb::mcbp::ClientOpcode::StartPersistence);
1429     applyFrameInfos(command, getFrameInfo);
1430 
1431     const auto response = execute(command);
1432     if (!response.isSuccess()) {
1433         throw ConnectionError("Failed to enablePersistence ",
1434                               response.getStatus());
1435     }
1436 }
1437 
disablePersistence( GetFrameInfoFunction getFrameInfo)1438 void MemcachedConnection::disablePersistence(
1439         GetFrameInfoFunction getFrameInfo) {
1440     BinprotGenericCommand command(cb::mcbp::ClientOpcode::StopPersistence);
1441     applyFrameInfos(command, getFrameInfo);
1442     const auto response = execute(command);
1443     if (!response.isSuccess()) {
1444         throw ConnectionError("Failed to disablePersistence ",
1445                               response.getStatus());
1446     }
1447 }
1448 
getMeta( const std::string& key, Vbid vbucket, GetMetaVersion version, GetFrameInfoFunction getFrameInfo)1449 std::pair<cb::mcbp::Status, GetMetaResponse> MemcachedConnection::getMeta(
1450         const std::string& key,
1451         Vbid vbucket,
1452         GetMetaVersion version,
1453         GetFrameInfoFunction getFrameInfo) {
1454     BinprotGenericCommand cmd{cb::mcbp::ClientOpcode::GetMeta, key};
1455     cmd.setVBucket(vbucket);
1456     const std::vector<uint8_t> extras = {uint8_t(version)};
1457     cmd.setExtras(extras);
1458     applyFrameInfos(cmd, getFrameInfo);
1459 
1460     auto resp = execute(cmd);
1461 
1462     GetMetaResponse meta;
1463     const auto ext = resp.getResponse().getExtdata();
1464     memcpy(&meta, ext.data(), ext.size());
1465     meta.deleted = ntohl(meta.deleted);
1466     meta.expiry = ntohl(meta.expiry);
1467     meta.seqno = ntohll(meta.seqno);
1468 
1469     return std::make_pair(resp.getStatus(), meta);
1470 }
1471 
getRandomKey(Vbid vbucket)1472 Document MemcachedConnection::getRandomKey(Vbid vbucket) {
1473     BinprotGenericCommand cmd{cb::mcbp::ClientOpcode::GetRandomKey};
1474     cmd.setVBucket(vbucket);
1475     const auto response = BinprotGetResponse(execute(cmd));
1476     if (!response.isSuccess()) {
1477         throw ConnectionError("Failed getRandomKey", response.getStatus());
1478     }
1479 
1480     Document ret;
1481     ret.info.flags = response.getDocumentFlags();
1482     ret.info.cas = response.getCas();
1483     ret.info.id = response.getKeyString();
1484     ret.info.datatype = response.getResponse().getDatatype();
1485     ret.value = response.getDataString();
1486     return ret;
1487 }
1488 
setUnorderedExecutionMode(ExecutionMode mode)1489 void MemcachedConnection::setUnorderedExecutionMode(ExecutionMode mode) {
1490     switch (mode) {
1491     case ExecutionMode::Ordered:
1492         setFeature(cb::mcbp::Feature::UnorderedExecution, false);
1493         return;
1494     case ExecutionMode::Unordered:
1495         setFeature(cb::mcbp::Feature::UnorderedExecution, true);
1496         return;
1497     }
1498     throw std::invalid_argument("setUnorderedExecutionMode: Invalid mode");
1499 }
1500 
execute(const BinprotCommand &command)1501 BinprotResponse MemcachedConnection::execute(const BinprotCommand &command) {
1502     BinprotResponse response;
1503     backoff_execute([&command, &response, this]() -> bool {
1504         sendCommand(command);
1505         recvResponse(response);
1506         return !(auto_retry_tmpfail &&
1507                  response.getStatus() == cb::mcbp::Status::Etmpfail);
1508     });
1509     return response;
1510 }
1511 
/**
 * Send a raw frame and wait for the reply frame.
 *
 * When auto_retry_tmpfail is set, a reply carrying Etmpfail causes the frame
 * to be re-sent via backoff_execute until another status is seen or the
 * default timeout expires.
 *
 * @param frame the encoded request to send
 * @return the last frame received from the server
 */
Frame MemcachedConnection::execute(const Frame& frame) {
    Frame response;
    auto attempt = [this, &frame, &response]() -> bool {
        sendFrame(frame);
        recvFrame(response);
        if (!auto_retry_tmpfail) {
            return true;
        }
        return response.getResponse()->getStatus() !=
               cb::mcbp::Status::Etmpfail;
    };
    backoff_execute(attempt);
    return response;
}
1522 
backoff_execute(std::function<bool()> executor, std::chrono::milliseconds backoff, std::chrono::seconds timeout)1523 void MemcachedConnection::backoff_execute(std::function<bool()> executor,
1524                                           std::chrono::milliseconds backoff,
1525                                           std::chrono::seconds timeout) {
1526     using std::chrono::steady_clock;
1527     const auto wait_timeout = steady_clock::now() + timeout;
1528     do {
1529         if (executor()) {
1530             return;
1531         }
1532         std::this_thread::sleep_for(backoff);
1533     } while (steady_clock::now() < wait_timeout);
1534     throw std::runtime_error(
1535             "MemcachedConnection::backoff_executor: Timed out after waiting "
1536             "more than " +
1537             std::to_string(timeout.count()) + " seconds");
1538 }
1539 
evict(const std::string& key, Vbid vbucket, GetFrameInfoFunction getFrameInfo)1540 void MemcachedConnection::evict(const std::string& key,
1541                                 Vbid vbucket,
1542                                 GetFrameInfoFunction getFrameInfo) {
1543     backoff_execute([this, &key, &vbucket]() -> bool {
1544         BinprotGenericCommand cmd(cb::mcbp::ClientOpcode::EvictKey, key);
1545         cmd.setVBucket(vbucket);
1546         const auto rsp = execute(cmd);
1547         if (rsp.isSuccess()) {
1548             // Evicted
1549             return true;
1550         }
1551         if (rsp.getStatus() == cb::mcbp::Status::KeyEexists) {
1552             return false;
1553         }
1554 
1555         throw ConnectionError("evict: Failed to evict key \"" + key + "\"",
1556                               rsp.getStatus());
1557     });
1558 }
1559 
setVbucket(Vbid vbid, vbucket_state_t state, const nlohmann::json& payload, GetFrameInfoFunction getFrameInfo)1560 void MemcachedConnection::setVbucket(Vbid vbid,
1561                                      vbucket_state_t state,
1562                                      const nlohmann::json& payload,
1563                                      GetFrameInfoFunction getFrameInfo) {
1564     BinprotSetVbucketCommand command{vbid, state, payload};
1565     applyFrameInfos(command, getFrameInfo);
1566 
1567     auto rsp = execute(command);
1568     if (!rsp.isSuccess()) {
1569         throw ConnectionError("setVbucket: Faled to set state",
1570                               rsp.getStatus());
1571     }
1572 }
1573 
applyFrameInfos(BinprotCommand& command, GetFrameInfoFunction& getFrameInfo)1574 void MemcachedConnection::applyFrameInfos(BinprotCommand& command,
1575                                           GetFrameInfoFunction& getFrameInfo) {
1576     if (getFrameInfo) {
1577         auto frame_info = getFrameInfo();
1578         for (const auto& fi : frame_info) {
1579             command.addFrameInfo(*fi);
1580         }
1581     }
1582 }
1583 
1584 /////////////////////////////////////////////////////////////////////////
1585 // Implementation of the ConnectionError class
1586 /////////////////////////////////////////////////////////////////////////
1587 
1588 // Generates error msgs like ``<prefix>: ["<context>", ]<reason> (#<reason>)``
formatMcbpExceptionMsg(const std::string& prefix, cb::mcbp::Status reason, const std::string& context = �)1589 static std::string formatMcbpExceptionMsg(const std::string& prefix,
1590                                           cb::mcbp::Status reason,
1591                                           const std::string& context = "") {
1592     // Format the error message
1593     std::string errormessage(prefix);
1594     errormessage.append(": ");
1595 
1596     if (!context.empty()) {
1597         errormessage.append("'");
1598         errormessage.append(context);
1599         errormessage.append("', ");
1600     }
1601 
1602     errormessage.append(to_string(reason));
1603     errormessage.append(" (");
1604     errormessage.append(std::to_string(uint16_t(reason)));
1605     errormessage.append(")");
1606     return errormessage;
1607 }
1608 
formatMcbpExceptionMsg(const std::string& prefix, const BinprotResponse& response)1609 static std::string formatMcbpExceptionMsg(const std::string& prefix,
1610                                           const BinprotResponse& response) {
1611     std::string context;
1612     // If the response was not a success and the datatype is json then there's
1613     // probably a JSON error context that's been included with the response body
1614     if (mcbp::datatype::is_json(response.getDatatype()) &&
1615         !response.isSuccess()) {
1616         nlohmann::json json;
1617         try {
1618             auto json = nlohmann::json::parse(response.getDataString());
1619             if (json.type() == nlohmann::json::value_t::object) {
1620                 auto error = json.find("error");
1621                 if (error != json.end()) {
1622                     auto ctx = error->find("context");
1623                     if (ctx != error->end() &&
1624                         ctx->type() == nlohmann::json::value_t::string) {
1625                         context = ctx->get<std::string>();
1626                     }
1627                 }
1628             }
1629         } catch (const nlohmann::json::exception&) {
1630         }
1631     }
1632     return formatMcbpExceptionMsg(prefix, response.getStatus(), context);
1633 }
1634 
ConnectionError(const std::string& prefix, cb::mcbp::Status reason)1635 ConnectionError::ConnectionError(const std::string& prefix,
1636                                  cb::mcbp::Status reason)
1637     : std::runtime_error(formatMcbpExceptionMsg(prefix, reason).c_str()),
1638       reason(reason) {
1639 }
1640 
ConnectionError(const std::string& prefix, const BinprotResponse& response)1641 ConnectionError::ConnectionError(const std::string& prefix,
1642                                  const BinprotResponse& response)
1643     : std::runtime_error(formatMcbpExceptionMsg(prefix, response).c_str()),
1644       reason(response.getStatus()),
1645       payload(response.getDataString()) {
1646 }
1647 
getErrorContext() const1648 std::string ConnectionError::getErrorContext() const {
1649     const auto decoded = nlohmann::json::parse(payload);
1650     return decoded["error"]["context"];
1651 }
1652 
getErrorJsonContext() const1653 nlohmann::json ConnectionError::getErrorJsonContext() const {
1654     return nlohmann::json::parse(payload);
1655 }
1656