ESPHome  2024.11.1
ota_esphome.cpp
Go to the documentation of this file.
1 #include "ota_esphome.h"
2 #ifdef USE_OTA
12 #include "esphome/core/hal.h"
13 #include "esphome/core/log.h"
14 #include "esphome/core/util.h"
15 
16 #include <cerrno>
17 #include <cstdio>
18 
19 namespace esphome {
20 
static const char *const TAG = "esphome.ota";
// Granularity of OTA protocol v2 progress acknowledgements: one ACK byte is
// sent back to the client for every OTA_BLOCK_SIZE bytes of payload received.
// Uses the standard <cstdint> type instead of the non-portable BSD `u_int16_t`.
static constexpr uint16_t OTA_BLOCK_SIZE = 8192;
23 
25 #ifdef USE_OTA_STATE_CALLBACK
27 #endif
28 
29  server_ = socket::socket_ip(SOCK_STREAM, 0);
30  if (server_ == nullptr) {
31  ESP_LOGW(TAG, "Could not create socket");
32  this->mark_failed();
33  return;
34  }
35  int enable = 1;
36  int err = server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
37  if (err != 0) {
38  ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
39  // we can still continue
40  }
41  err = server_->setblocking(false);
42  if (err != 0) {
43  ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
44  this->mark_failed();
45  return;
46  }
47 
48  struct sockaddr_storage server;
49 
50  socklen_t sl = socket::set_sockaddr_any((struct sockaddr *) &server, sizeof(server), this->port_);
51  if (sl == 0) {
52  ESP_LOGW(TAG, "Socket unable to set sockaddr: errno %d", errno);
53  this->mark_failed();
54  return;
55  }
56 
57  err = server_->bind((struct sockaddr *) &server, sizeof(server));
58  if (err != 0) {
59  ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
60  this->mark_failed();
61  return;
62  }
63 
64  err = server_->listen(4);
65  if (err != 0) {
66  ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
67  this->mark_failed();
68  return;
69  }
70 }
71 
73  ESP_LOGCONFIG(TAG, "Over-The-Air updates:");
74  ESP_LOGCONFIG(TAG, " Address: %s:%u", network::get_use_address().c_str(), this->port_);
75  ESP_LOGCONFIG(TAG, " Version: %d", USE_OTA_VERSION);
76 #ifdef USE_OTA_PASSWORD
77  if (!this->password_.empty()) {
78  ESP_LOGCONFIG(TAG, " Password configured");
79  }
80 #endif
81 }
82 
84 
// Bit in the client's one-byte feature set indicating it can send a compressed
// image; honoured only when the OTA backend also supports compression.
static constexpr uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01;
86 
89  bool update_started = false;
90  size_t total = 0;
91  uint32_t last_progress = 0;
92  uint8_t buf[1024];
93  char *sbuf = reinterpret_cast<char *>(buf);
94  size_t ota_size;
95  uint8_t ota_features;
96  std::unique_ptr<ota::OTABackend> backend;
97  (void) ota_features;
98 #if USE_OTA_VERSION == 2
99  size_t size_acknowledged = 0;
100 #endif
101 
102  if (client_ == nullptr) {
103  struct sockaddr_storage source_addr;
104  socklen_t addr_len = sizeof(source_addr);
105  client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);
106  }
107  if (client_ == nullptr)
108  return;
109 
110  int enable = 1;
111  int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
112  if (err != 0) {
113  ESP_LOGW(TAG, "Socket could not enable TCP nodelay, errno %d", errno);
114  return;
115  }
116 
117  ESP_LOGD(TAG, "Starting update from %s...", this->client_->getpeername().c_str());
118  this->status_set_warning();
119 #ifdef USE_OTA_STATE_CALLBACK
120  this->state_callback_.call(ota::OTA_STARTED, 0.0f, 0);
121 #endif
122 
123  if (!this->readall_(buf, 5)) {
124  ESP_LOGW(TAG, "Reading magic bytes failed");
125  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
126  }
127  // 0x6C, 0x26, 0xF7, 0x5C, 0x45
128  if (buf[0] != 0x6C || buf[1] != 0x26 || buf[2] != 0xF7 || buf[3] != 0x5C || buf[4] != 0x45) {
129  ESP_LOGW(TAG, "Magic bytes do not match! 0x%02X-0x%02X-0x%02X-0x%02X-0x%02X", buf[0], buf[1], buf[2], buf[3],
130  buf[4]);
131  error_code = ota::OTA_RESPONSE_ERROR_MAGIC;
132  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
133  }
134 
135  // Send OK and version - 2 bytes
136  buf[0] = ota::OTA_RESPONSE_OK;
137  buf[1] = USE_OTA_VERSION;
138  this->writeall_(buf, 2);
139 
140  backend = ota::make_ota_backend();
141 
142  // Read features - 1 byte
143  if (!this->readall_(buf, 1)) {
144  ESP_LOGW(TAG, "Reading features failed");
145  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
146  }
147  ota_features = buf[0]; // NOLINT
148  ESP_LOGV(TAG, "Features: 0x%02X", ota_features);
149 
150  // Acknowledge header - 1 byte
152  if ((ota_features & FEATURE_SUPPORTS_COMPRESSION) != 0 && backend->supports_compression()) {
154  }
155 
156  this->writeall_(buf, 1);
157 
158 #ifdef USE_OTA_PASSWORD
159  if (!this->password_.empty()) {
161  this->writeall_(buf, 1);
162  md5::MD5Digest md5{};
163  md5.init();
164  sprintf(sbuf, "%08" PRIx32, random_uint32());
165  md5.add(sbuf, 8);
166  md5.calculate();
167  md5.get_hex(sbuf);
168  ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
169 
170  // Send nonce, 32 bytes hex MD5
171  if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
172  ESP_LOGW(TAG, "Auth: Writing nonce failed");
173  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
174  }
175 
176  // prepare challenge
177  md5.init();
178  md5.add(this->password_.c_str(), this->password_.length());
179  // add nonce
180  md5.add(sbuf, 32);
181 
182  // Receive cnonce, 32 bytes hex MD5
183  if (!this->readall_(buf, 32)) {
184  ESP_LOGW(TAG, "Auth: Reading cnonce failed");
185  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
186  }
187  sbuf[32] = '\0';
188  ESP_LOGV(TAG, "Auth: CNonce is %s", sbuf);
189  // add cnonce
190  md5.add(sbuf, 32);
191 
192  // calculate result
193  md5.calculate();
194  md5.get_hex(sbuf);
195  ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
196 
197  // Receive result, 32 bytes hex MD5
198  if (!this->readall_(buf + 64, 32)) {
199  ESP_LOGW(TAG, "Auth: Reading response failed");
200  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
201  }
202  sbuf[64 + 32] = '\0';
203  ESP_LOGV(TAG, "Auth: Response is %s", sbuf + 64);
204 
205  bool matches = true;
206  for (uint8_t i = 0; i < 32; i++)
207  matches = matches && buf[i] == buf[64 + i];
208 
209  if (!matches) {
210  ESP_LOGW(TAG, "Auth failed! Passwords do not match");
212  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
213  }
214  }
215 #endif // USE_OTA_PASSWORD
216 
217  // Acknowledge auth OK - 1 byte
218  buf[0] = ota::OTA_RESPONSE_AUTH_OK;
219  this->writeall_(buf, 1);
220 
221  // Read size, 4 bytes MSB first
222  if (!this->readall_(buf, 4)) {
223  ESP_LOGW(TAG, "Reading size failed");
224  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
225  }
226  ota_size = 0;
227  for (uint8_t i = 0; i < 4; i++) {
228  ota_size <<= 8;
229  ota_size |= buf[i];
230  }
231  ESP_LOGV(TAG, "Size is %u bytes", ota_size);
232 
233  error_code = backend->begin(ota_size);
234  if (error_code != ota::OTA_RESPONSE_OK)
235  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
236  update_started = true;
237 
238  // Acknowledge prepare OK - 1 byte
240  this->writeall_(buf, 1);
241 
242  // Read binary MD5, 32 bytes
243  if (!this->readall_(buf, 32)) {
244  ESP_LOGW(TAG, "Reading binary MD5 checksum failed");
245  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
246  }
247  sbuf[32] = '\0';
248  ESP_LOGV(TAG, "Update: Binary MD5 is %s", sbuf);
249  backend->set_update_md5(sbuf);
250 
251  // Acknowledge MD5 OK - 1 byte
253  this->writeall_(buf, 1);
254 
255  while (total < ota_size) {
256  // TODO: timeout check
257  size_t requested = std::min(sizeof(buf), ota_size - total);
258  ssize_t read = this->client_->read(buf, requested);
259  if (read == -1) {
260  if (errno == EAGAIN || errno == EWOULDBLOCK) {
261  App.feed_wdt();
262  delay(1);
263  continue;
264  }
265  ESP_LOGW(TAG, "Error receiving data for update, errno %d", errno);
266  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
267  } else if (read == 0) {
268  // $ man recv
269  // "When a stream socket peer has performed an orderly shutdown, the return value will
270  // be 0 (the traditional "end-of-file" return)."
271  ESP_LOGW(TAG, "Remote end closed connection");
272  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
273  }
274 
275  error_code = backend->write(buf, read);
276  if (error_code != ota::OTA_RESPONSE_OK) {
277  ESP_LOGW(TAG, "Error writing binary data to flash!, error_code: %d", error_code);
278  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
279  }
280  total += read;
281 #if USE_OTA_VERSION == 2
282  while (size_acknowledged + OTA_BLOCK_SIZE <= total || (total == ota_size && size_acknowledged < ota_size)) {
284  this->writeall_(buf, 1);
285  size_acknowledged += OTA_BLOCK_SIZE;
286  }
287 #endif
288 
289  uint32_t now = millis();
290  if (now - last_progress > 1000) {
291  last_progress = now;
292  float percentage = (total * 100.0f) / ota_size;
293  ESP_LOGD(TAG, "Progress: %0.1f%%", percentage);
294 #ifdef USE_OTA_STATE_CALLBACK
295  this->state_callback_.call(ota::OTA_IN_PROGRESS, percentage, 0);
296 #endif
297  // feed watchdog and give other tasks a chance to run
298  App.feed_wdt();
299  yield();
300  }
301  }
302 
303  // Acknowledge receive OK - 1 byte
305  this->writeall_(buf, 1);
306 
307  error_code = backend->end();
308  if (error_code != ota::OTA_RESPONSE_OK) {
309  ESP_LOGW(TAG, "Error ending update! error_code: %d", error_code);
310  goto error; // NOLINT(cppcoreguidelines-avoid-goto)
311  }
312 
313  // Acknowledge Update end OK - 1 byte
315  this->writeall_(buf, 1);
316 
317  // Read ACK
318  if (!this->readall_(buf, 1) || buf[0] != ota::OTA_RESPONSE_OK) {
319  ESP_LOGW(TAG, "Reading back acknowledgement failed");
320  // do not go to error, this is not fatal
321  }
322 
323  this->client_->close();
324  this->client_ = nullptr;
325  delay(10);
326  ESP_LOGI(TAG, "Update complete");
327  this->status_clear_warning();
328 #ifdef USE_OTA_STATE_CALLBACK
329  this->state_callback_.call(ota::OTA_COMPLETED, 100.0f, 0);
330 #endif
331  delay(100); // NOLINT
332  App.safe_reboot();
333 
334 error:
335  buf[0] = static_cast<uint8_t>(error_code);
336  this->writeall_(buf, 1);
337  this->client_->close();
338  this->client_ = nullptr;
339 
340  if (backend != nullptr && update_started) {
341  backend->abort();
342  }
343 
344  this->status_momentary_error("onerror", 5000);
345 #ifdef USE_OTA_STATE_CALLBACK
346  this->state_callback_.call(ota::OTA_ERROR, 0.0f, static_cast<uint8_t>(error_code));
347 #endif
348 }
349 
350 bool ESPHomeOTAComponent::readall_(uint8_t *buf, size_t len) {
351  uint32_t start = millis();
352  uint32_t at = 0;
353  while (len - at > 0) {
354  uint32_t now = millis();
355  if (now - start > 1000) {
356  ESP_LOGW(TAG, "Timed out reading %d bytes of data", len);
357  return false;
358  }
359 
360  ssize_t read = this->client_->read(buf + at, len - at);
361  if (read == -1) {
362  if (errno == EAGAIN || errno == EWOULDBLOCK) {
363  App.feed_wdt();
364  delay(1);
365  continue;
366  }
367  ESP_LOGW(TAG, "Failed to read %d bytes of data, errno %d", len, errno);
368  return false;
369  } else if (read == 0) {
370  ESP_LOGW(TAG, "Remote closed connection");
371  return false;
372  } else {
373  at += read;
374  }
375  App.feed_wdt();
376  delay(1);
377  }
378 
379  return true;
380 }
381 bool ESPHomeOTAComponent::writeall_(const uint8_t *buf, size_t len) {
382  uint32_t start = millis();
383  uint32_t at = 0;
384  while (len - at > 0) {
385  uint32_t now = millis();
386  if (now - start > 1000) {
387  ESP_LOGW(TAG, "Timed out writing %d bytes of data", len);
388  return false;
389  }
390 
391  ssize_t written = this->client_->write(buf + at, len - at);
392  if (written == -1) {
393  if (errno == EAGAIN || errno == EWOULDBLOCK) {
394  App.feed_wdt();
395  delay(1);
396  continue;
397  }
398  ESP_LOGW(TAG, "Failed to write %d bytes of data, errno %d", len, errno);
399  return false;
400  } else {
401  at += written;
402  }
403  App.feed_wdt();
404  delay(1);
405  }
406  return true;
407 }
408 
410 uint16_t ESPHomeOTAComponent::get_port() const { return this->port_; }
411 void ESPHomeOTAComponent::set_port(uint16_t port) { this->port_ = port; }
412 } // namespace esphome
413 #endif
void init()
Initialize a new MD5 digest computation.
Definition: md5.cpp:11
std::unique_ptr< Socket > socket_ip(int type, int protocol)
Create a socket in the newest available IP domain (IPv6 or IPv4) of the given type and protocol...
Definition: socket.cpp:13
const float AFTER_WIFI
For components that should be initialized after WiFi is connected.
Definition: component.cpp:26
socklen_t set_sockaddr_any(struct sockaddr *addr, socklen_t addrlen, uint16_t port)
Set a sockaddr to the any address and specified port for the IP version used by socket_ip().
Definition: socket.cpp:51
std::string get_use_address()
Get the active network hostname.
Definition: util.cpp:52
void status_set_warning(const char *message="unspecified")
Definition: component.cpp:151
uint32_t random_uint32()
Return a random 32-bit unsigned integer.
Definition: helpers.cpp:193
uint32_t socklen_t
Definition: headers.h:97
std::unique_ptr< socket::Socket > client_
Definition: ota_esphome.h:41
void register_ota_platform(OTAComponent *ota_caller)
Definition: ota_backend.cpp:16
uint32_t IRAM_ATTR HOT millis()
Definition: core.cpp:25
void status_momentary_error(const std::string &name, uint32_t length=5000)
Definition: component.cpp:182
CallbackManager< void(ota::OTAState, float, uint8_t)> state_callback_
Definition: ota_backend.h:70
void status_clear_warning()
Definition: component.cpp:166
Application App
Global storage of Application pointer - only one Application can exist.
std::unique_ptr< ota::OTABackend > make_ota_backend()
bool writeall_(const uint8_t *buf, size_t len)
std::string size_t len
Definition: helpers.h:293
void IRAM_ATTR HOT yield()
Definition: core.cpp:24
float get_setup_priority() const override
std::unique_ptr< socket::Socket > server_
Definition: ota_esphome.h:40
virtual void mark_failed()
Mark this component as failed.
Definition: component.cpp:118
Implementation of SPI Controller mode.
Definition: a01nyub.cpp:7
void set_port(uint16_t port)
Manually set the port OTA should listen on.
bool readall_(uint8_t *buf, size_t len)
void IRAM_ATTR HOT delay(uint32_t ms)
Definition: core.cpp:26