/
/
/
#include "bag_recorder_backend/websocket_manager.hpp"

#include <algorithm>
#include <chrono>
#include <ctime>
#include <iomanip>
#include <sstream>
#include <string>
#include <utility>

#include <nlohmann/json.hpp>
5
6namespace bag_recorder_backend
7{
8
9WebSocketManager::WebSocketManager()
10: total_messages_sent_(0),
11 total_messages_received_(0)
12{
13}
14
15WebSocketManager::~WebSocketManager()
16{
17 std::lock_guard<std::mutex> lock(connections_mutex_);
18 connections_.clear();
19}
20
21void WebSocketManager::add_connection(crow::websocket::connection* conn)
22{
23 std::lock_guard<std::mutex> lock(connections_mutex_);
24 connections_.insert(conn);
25}
26
27void WebSocketManager::remove_connection(crow::websocket::connection* conn)
28{
29 std::lock_guard<std::mutex> lock(connections_mutex_);
30 remove_connection_unsafe(conn);
31}
32
33size_t WebSocketManager::get_connection_count() const
34{
35 std::lock_guard<std::mutex> lock(connections_mutex_);
36 return connections_.size();
37}
38
39void WebSocketManager::broadcast_message(const WebSocketMessage & message)
40{
41 std::lock_guard<std::mutex> lock(connections_mutex_);
42
43 std::string serialized_message = serialize_message(message);
44
45 for (auto it = connections_.begin(); it != connections_.end();) {
46 try {
47 if (is_connection_valid(*it)) {
48 (*it)->send_text(serialized_message);
49 ++it;
50 } else {
51 it = connections_.erase(it);
52 }
53 } catch (const std::exception &) {
54 it = connections_.erase(it);
55 }
56 }
57
58 total_messages_sent_++;
59}
60
61void WebSocketManager::broadcast_status_update(const std::string & status_json)
62{
63 WebSocketMessage message;
64 message.type = MessageType::STATUS_UPDATE;
65 message.data = status_json;
66 message.timestamp = get_current_timestamp();
67
68 broadcast_message(message);
69}
70
71void WebSocketManager::broadcast_recording_started(const std::string & config_json)
72{
73 WebSocketMessage message;
74 message.type = MessageType::RECORDING_STARTED;
75 message.data = config_json;
76 message.timestamp = get_current_timestamp();
77
78 broadcast_message(message);
79}
80
81void WebSocketManager::broadcast_recording_stopped()
82{
83 WebSocketMessage message;
84 message.type = MessageType::RECORDING_STOPPED;
85 message.data = "{}";
86 message.timestamp = get_current_timestamp();
87
88 broadcast_message(message);
89}
90
91void WebSocketManager::broadcast_error(const std::string & error_message)
92{
93 WebSocketMessage message;
94 message.type = MessageType::ERROR_MESSAGE;
95 message.data = error_message;
96 message.timestamp = get_current_timestamp();
97
98 broadcast_message(message);
99}
100
101void WebSocketManager::broadcast_topic_list(const std::string & topics_json)
102{
103 WebSocketMessage message;
104 message.type = MessageType::TOPIC_LIST_UPDATE;
105 message.data = topics_json;
106 message.timestamp = get_current_timestamp();
107
108 broadcast_message(message);
109}
110
111void WebSocketManager::send_message(crow::websocket::connection* conn, const WebSocketMessage & message)
112{
113 if (!is_connection_valid(conn)) {
114 remove_connection(conn);
115 return;
116 }
117
118 try {
119 std::string serialized_message = serialize_message(message);
120 conn->send_text(serialized_message);
121 total_messages_sent_++;
122 } catch (const std::exception &) {
123 remove_connection(conn);
124 }
125}
126
127void WebSocketManager::send_ping(crow::websocket::connection* conn)
128{
129 WebSocketMessage message;
130 message.type = MessageType::PING;
131 message.data = "{}";
132 message.timestamp = get_current_timestamp();
133
134 send_message(conn, message);
135}
136
137void WebSocketManager::send_pong(crow::websocket::connection* conn)
138{
139 WebSocketMessage message;
140 message.type = MessageType::PONG;
141 message.data = "{}";
142 message.timestamp = get_current_timestamp();
143
144 send_message(conn, message);
145}
146
147void WebSocketManager::handle_client_message(crow::websocket::connection* conn, const std::string & message)
148{
149 total_messages_received_++;
150
151 try {
152 WebSocketMessage parsed_message = parse_client_message(message);
153
154 switch (parsed_message.type) {
155 case MessageType::PING:
156 send_pong(conn);
157 if (ping_callback_) {
158 ping_callback_(conn);
159 }
160 break;
161
162 case MessageType::PONG:
163 // Handle pong if needed
164 break;
165
166 default:
167 if (message_callback_) {
168 message_callback_(conn, parsed_message.data);
169 }
170 break;
171 }
172 } catch (const std::exception &) {
173 // Invalid message format, ignore or send error
174 broadcast_error("Invalid message format");
175 }
176}
177
178bool WebSocketManager::is_connection_valid(crow::websocket::connection* conn) const
179{
180 // Basic validity check - in a real implementation, you might want more sophisticated checks
181 return conn != nullptr;
182}
183
184void WebSocketManager::cleanup_invalid_connections()
185{
186 std::lock_guard<std::mutex> lock(connections_mutex_);
187
188 for (auto it = connections_.begin(); it != connections_.end();) {
189 if (!is_connection_valid(*it)) {
190 it = connections_.erase(it);
191 } else {
192 ++it;
193 }
194 }
195}
196
197void WebSocketManager::set_ping_callback(std::function<void(crow::websocket::connection*)> callback)
198{
199 ping_callback_ = callback;
200}
201
202void WebSocketManager::set_message_callback(std::function<void(crow::websocket::connection*, const std::string&)> callback)
203{
204 message_callback_ = callback;
205}
206
207WebSocketMessage WebSocketManager::parse_client_message(const std::string & raw_message)
208{
209 auto json_data = nlohmann::json::parse(raw_message);
210
211 WebSocketMessage message;
212 message.type = string_to_message_type(json_data["type"]);
213 message.data = json_data.value("data", "{}");
214 message.timestamp = json_data.value("timestamp", get_current_timestamp());
215
216 return message;
217}
218
219std::string WebSocketManager::serialize_message(const WebSocketMessage & message)
220{
221 nlohmann::json json_data;
222 json_data["type"] = message_type_to_string(message.type);
223 json_data["data"] = nlohmann::json::parse(message.data);
224 json_data["timestamp"] = message.timestamp;
225
226 return json_data.dump();
227}
228
229std::string WebSocketManager::get_current_timestamp()
230{
231 auto now = std::chrono::system_clock::now();
232 auto time_t = std::chrono::system_clock::to_time_t(now);
233 auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
234 now.time_since_epoch()) % 1000;
235
236 std::stringstream ss;
237 ss << std::put_time(std::gmtime(&time_t), "%Y-%m-%dT%H:%M:%S");
238 ss << '.' << std::setfill('0') << std::setw(3) << ms.count() << 'Z';
239
240 return ss.str();
241}
242
243void WebSocketManager::remove_connection_unsafe(crow::websocket::connection* conn)
244{
245 connections_.erase(conn);
246}
247
248// Helper functions
249std::string message_type_to_string(MessageType type)
250{
251 switch (type) {
252 case MessageType::STATUS_UPDATE:
253 return "status_update";
254 case MessageType::RECORDING_STARTED:
255 return "recording_started";
256 case MessageType::RECORDING_STOPPED:
257 return "recording_stopped";
258 case MessageType::ERROR_MESSAGE:
259 return "error_message";
260 case MessageType::TOPIC_LIST_UPDATE:
261 return "topic_list_update";
262 case MessageType::PING:
263 return "ping";
264 case MessageType::PONG:
265 return "pong";
266 default:
267 return "unknown";
268 }
269}
270
271MessageType string_to_message_type(const std::string & type_str)
272{
273 if (type_str == "status_update") return MessageType::STATUS_UPDATE;
274 if (type_str == "recording_started") return MessageType::RECORDING_STARTED;
275 if (type_str == "recording_stopped") return MessageType::RECORDING_STOPPED;
276 if (type_str == "error_message") return MessageType::ERROR_MESSAGE;
277 if (type_str == "topic_list_update") return MessageType::TOPIC_LIST_UPDATE;
278 if (type_str == "ping") return MessageType::PING;
279 if (type_str == "pong") return MessageType::PONG;
280
281 return MessageType::STATUS_UPDATE; // Default fallback
282}
283
284} // namespace bag_recorder_backend