# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

import abc
import errno
import os
import platform
import socket
import threading
import time
import traceback
import urlparse

import mozprocess


__all__ = ["SeleniumServer", "ChromeDriverServer",
           "GeckoDriverServer", "WebDriverServer"]


class WebDriverServer(object):
    """Base class managing a WebDriver HTTP server subprocess.

    Subclasses implement :meth:`make_command` to build the command line.
    This class launches that command through ``mozprocess``, forwards each
    output line to ``logger.process_output`` and waits until
    ``(host, port)`` accepts TCP connections.
    """
    __metaclass__ = abc.ABCMeta

    # URL path prefix used when the caller passes an empty ``base_path``.
    default_base_path = "/"
    # Ports already handed out by _find_next_free_port. This is a class
    # attribute, so it is shared by every instance (and subclass) in this
    # process and two servers are not assigned the same port.
    _used_ports = set()

    def __init__(self, logger, binary, host="127.0.0.1", port=None,
                 base_path="", env=None):
        """
        :param logger: Logger object; must provide ``debug``, ``error`` and
            ``process_output``.
        :param binary: Path or name of the server executable.
        :param host: Host the server listens on and that is polled.
        :param port: Port to use; if None a free one is chosen lazily.
        :param base_path: URL path prefix; "" means ``default_base_path``.
        :param env: Environment for the child process; defaults to a copy
            of ``os.environ``.
        """
        self.logger = logger
        self.binary = binary
        self.host = host
        if base_path == "":
            self.base_path = self.default_base_path
        else:
            self.base_path = base_path
        self.env = os.environ.copy() if env is None else env

        self._port = port
        self._cmd = None
        self._proc = None

    @abc.abstractmethod
    def make_command(self):
        """Returns the full command for starting the server process as a list."""

    def start(self, block=True):
        """Start the server process.

        :param block: If true, wait for the server process to exit after it
            has become reachable.

        A KeyboardInterrupt raised while starting or waiting stops the
        server process.
        """
        try:
            self._run(block)
        except KeyboardInterrupt:
            # NOTE(review): the interrupt is swallowed rather than re-raised,
            # so callers cannot observe it; kept for compatibility.
            self.stop()

    def _run(self, block):
        """Launch the process, wait for its port, optionally wait for exit.

        :raises IOError: if the server executable does not exist.
        :raises socket.error: (re-raised from wait_for_service) if the
            server does not become reachable in time.
        """
        self._cmd = self.make_command()
        self._proc = mozprocess.ProcessHandler(
            self._cmd,
            processOutputLine=self.on_output,
            env=self.env,
            storeOutput=False)

        try:
            self._proc.run()
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise IOError(
                    "WebDriver HTTP server executable not found: %s" % self.binary)
            raise

        self.logger.debug(
            "Waiting for server to become accessible: %s" % self.url)
        try:
            wait_for_service((self.host, self.port))
        except Exception:
            # Only real failures are logged here; a KeyboardInterrupt
            # propagates untouched to start(), which stops the server.
            self.logger.error(
                "WebDriver HTTP server was not accessible "
                "within the timeout:\n%s" % traceback.format_exc())
            # poll() returns None while the process is still running; an
            # exit code of 0 is still an exit and worth reporting.
            if self._proc.poll() is not None:
                self.logger.error("Webdriver server process exited with code %i" %
                                  self._proc.returncode)
            raise

        if block:
            self._proc.wait()

    def stop(self):
        """Kill the server process if it is running.

        Returns the value of ``kill()`` when the process was alive.
        Otherwise returns ``True`` (the process is already not alive).
        """
        if self.is_alive:
            return self._proc.kill()
        return not self.is_alive

    @property
    def is_alive(self):
        """True if the process has been launched and has not yet exited."""
        return (self._proc is not None and
                self._proc.proc is not None and
                self._proc.poll() is None)

    def on_output(self, line):
        """Forward one line of server output to the logger.

        Invalid UTF-8 bytes are replaced rather than raising.
        """
        self.logger.process_output(self.pid,
                                   line.decode("utf8", "replace"),
                                   command=" ".join(self._cmd))

    @property
    def pid(self):
        """PID of the server process, or None if it was never started."""
        if self._proc is not None:
            return self._proc.pid

    @property
    def url(self):
        """Base URL of the server, e.g. ``http://127.0.0.1:4444/``."""
        return "http://%s:%i%s" % (self.host, self.port, self.base_path)

    @property
    def port(self):
        """Server port; picks (and reserves) a free one on first access."""
        if self._port is None:
            self._port = self._find_next_free_port()
        return self._port

    @staticmethod
    def _find_next_free_port():
        """Return the first unbound port >= 4444 not already handed out."""
        port = get_free_port(4444, exclude=WebDriverServer._used_ports)
        WebDriverServer._used_ports.add(port)
        return port


class SeleniumServer(WebDriverServer):
    """Selenium standalone server, run as ``java -jar <binary>``."""
    default_base_path = "/wd/hub"

    def make_command(self):
        return ["java", "-jar", self.binary, "-port", str(self.port)]


class ChromeDriverServer(WebDriverServer):
    """ChromeDriver server process."""
    default_base_path = "/wd/hub"

    def __init__(self, logger, binary="chromedriver", port=None,
                 base_path=""):
        WebDriverServer.__init__(
            self, logger, binary, port=port, base_path=base_path)

    def make_command(self):
        cmd = [self.binary, cmd_arg("port", str(self.port))]
        # Omit the url-base flag entirely when there is no base path,
        # instead of passing an empty-string argument to the executable.
        if self.base_path:
            cmd.append(cmd_arg("url-base", self.base_path))
        return cmd


class GeckoDriverServer(WebDriverServer):
    """geckodriver server that attaches to an existing Marionette port."""

    def __init__(self, logger, marionette_port=2828, binary="wires",
                 host="127.0.0.1", port=None):
        env = os.environ.copy()
        # Ask the Rust binary to print backtraces on panic.
        env["RUST_BACKTRACE"] = "1"
        WebDriverServer.__init__(self, logger, binary, host=host, port=port,
                                 env=env)
        self.marionette_port = marionette_port

    def make_command(self):
        return [self.binary,
                "--connect-existing",
                "--marionette-port", str(self.marionette_port),
                "--host", self.host,
                "--port", str(self.port)]


def cmd_arg(name, value=None):
    """Format a command-line flag: ``-name`` on Windows, ``--name`` elsewhere.

    If ``value`` is given it is appended as ``=value``.
    """
    prefix = "-" if platform.system() == "Windows" else "--"
    rv = prefix + name
    if value is not None:
        rv += "=" + value
    return rv


def get_free_port(start_port, exclude=None):
    """Get the first port number after start_port (inclusive) that is
    not currently bound.

    A port counts as free when a socket can be bound to it on 127.0.0.1.

    :param start_port: Integer port number at which to start testing.
    :param exclude: Set of port numbers to skip
    :raises socket.error: if no free port exists in [start_port, 65535]."""
    port = start_port
    # 65535 is the largest valid TCP port; beyond it bind() would fail with
    # an unrelated OverflowError instead of a meaningful error.
    while port <= 65535:
        if exclude and port in exclude:
            port += 1
            continue
        s = socket.socket()
        try:
            s.bind(("127.0.0.1", port))
        except socket.error:
            port += 1
        else:
            return port
        finally:
            s.close()
    raise socket.error("No free port available at or above %i" % start_port)


def wait_for_service(addr, timeout=15):
    """Waits until network service given as a tuple of (host, port) becomes
    available or the `timeout` duration is reached, at which point
    ``socket.error`` is raised.

    Connection-refused errors and per-attempt timeouts are retried every
    0.5 seconds; any other socket error is re-raised immediately."""
    end = time.time() + timeout
    while end > time.time():
        so = socket.socket()
        # Bound each attempt by the remaining time, so a single connect()
        # to an unresponsive host cannot block past the overall deadline.
        so.settimeout(max(end - time.time(), 0.1))
        try:
            so.connect(addr)
        except socket.timeout:
            pass
        except socket.error as e:
            # "Connection refused" just means nothing is listening yet.
            # e.errno works on both Python 2 and 3, unlike e[0].
            if e.errno != errno.ECONNREFUSED:
                raise
        else:
            return True
        finally:
            so.close()
        time.sleep(0.5)
    raise socket.error("Service is unavailable: %s:%i" % addr)