1 /* -*- Mode: C++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /*
3  *     Copyright 2017 Couchbase, Inc.
4  *
5  *   Licensed under the Apache License, Version 2.0 (the "License");
6  *   you may not use this file except in compliance with the License.
7  *   You may obtain a copy of the License at
8  *
9  *       http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *   Unless required by applicable law or agreed to in writing, software
12  *   distributed under the License is distributed on an "AS IS" BASIS,
13  *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *   See the License for the specific language governing permissions and
15  *   limitations under the License.
16  */
17 #include "client_connection_map.h"
18 
19 /////////////////////////////////////////////////////////////////////////
20 // Implementation of the ConnectionMap class
21 /////////////////////////////////////////////////////////////////////////
getConnection(bool ssl, sa_family_t family, in_port_t port)22 MemcachedConnection& ConnectionMap::getConnection(bool ssl,
23                                                   sa_family_t family,
24                                                   in_port_t port) {
25     for (auto& conn : connections) {
26         if (conn->isSsl() == ssl && conn->getFamily() == family &&
27             (port == 0 || conn->getPort() == port)) {
28             return *conn.get();
29         }
30     }
31 
32     throw std::runtime_error("No connection matching the request");
33 }
34 
contains(bool ssl, sa_family_t family)35 bool ConnectionMap::contains(bool ssl, sa_family_t family) {
36     try {
37         (void)getConnection(ssl, family, 0);
38         return true;
39     } catch (const std::runtime_error&) {
40         return false;
41     }
42 }
43 
initialize(cJSON* ports)44 void ConnectionMap::initialize(cJSON* ports) {
45     invalidate();
46     cJSON* array = cJSON_GetObjectItem(ports, "ports");
47     if (array == nullptr) {
48         std::string msg("ports not found in portnumber file: ");
49         msg.append(to_string(ports, false));
50         throw std::runtime_error(msg);
51     }
52 
53     auto numEntries = cJSON_GetArraySize(array);
54     sa_family_t family;
55     for (int ii = 0; ii < numEntries; ++ii) {
56         auto obj = cJSON_GetArrayItem(array, ii);
57         auto fam = cJSON_GetObjectItem(obj, "family");
58         if (strcmp(fam->valuestring, "AF_INET") == 0) {
59             family = AF_INET;
60         } else if (strcmp(fam->valuestring, "AF_INET6") == 0) {
61             family = AF_INET6;
62         } else {
63             std::string msg("Unsupported network family: ");
64             msg.append(to_string(obj, false));
65             throw std::runtime_error(msg);
66         }
67 
68         auto ssl = cJSON_GetObjectItem(obj, "ssl");
69         if (ssl == nullptr) {
70             std::string msg("ssl missing for entry: ");
71             msg.append(to_string(obj, false));
72             throw std::runtime_error(msg);
73         }
74 
75         auto port = cJSON_GetObjectItem(obj, "port");
76         if (port == nullptr) {
77             std::string msg("port missing for entry: ");
78             msg.append(to_string(obj, false));
79             throw std::runtime_error(msg);
80         }
81 
82         auto protocol = cJSON_GetObjectItem(obj, "protocol");
83         if (protocol == nullptr) {
84             std::string msg("protocol missing for entry: ");
85             msg.append(to_string(obj, false));
86             throw std::runtime_error(msg);
87         }
88 
89         auto portval = static_cast<in_port_t>(port->valueint);
90         bool useSsl = ssl->type == cJSON_True ? true : false;
91 
92         MemcachedConnection* connection;
93         if (strcmp(protocol->valuestring, "memcached") != 0) {
94             throw std::logic_error(
95                     "ConnectionMap::initialize: Invalid value passed for "
96                     "protocol: " +
97                     std::string(protocol->valuestring));
98         }
99 
100         connection = new MemcachedConnection("", portval, family, useSsl);
101         connections.push_back(std::unique_ptr<MemcachedConnection>{connection});
102     }
103 }
104 
invalidate()105 void ConnectionMap::invalidate() {
106     connections.resize(0);
107 }
108