summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--tests/unit/test_ssh.py19
-rw-r--r--yardstick/ssh.py28
2 files changed, 37 insertions, 10 deletions
diff --git a/tests/unit/test_ssh.py b/tests/unit/test_ssh.py
index a27052462..1e021a051 100644
--- a/tests/unit/test_ssh.py
+++ b/tests/unit/test_ssh.py
@@ -18,6 +18,8 @@
import os
import unittest
+from cStringIO import StringIO
+
import mock
from yardstick import ssh
@@ -275,6 +277,23 @@ class SSHRunTestCase(unittest.TestCase):
self.assertEqual(send_calls, self.fake_session.send.mock_calls)
@mock.patch("yardstick.ssh.select")
+ def test_run_stdin_keep_open(self, mock_select):
+ """Test run method with stdin.
+
+ Third send call was called with "e2" because only 3 bytes was sent
+ by second call. So remainig 2 bytes of "line2" was sent by third call.
+ """
+ mock_select.select.return_value = ([], [], [])
+ self.fake_session.exit_status_ready.side_effect = [0, 0, 0, True]
+ self.fake_session.send_ready.return_value = True
+ self.fake_session.send.side_effect = len
+ fake_stdin = StringIO("line1\nline2\n")
+ self.test_client.run("cmd", stdin=fake_stdin, keep_stdin_open=True)
+ call = mock.call
+ send_calls = [call("line1\nline2\n")]
+ self.assertEqual(send_calls, self.fake_session.send.mock_calls)
+
+ @mock.patch("yardstick.ssh.select")
def test_run_select_error(self, mock_select):
self.fake_session.exit_status_ready.return_value = False
mock_select.select.return_value = ([], [], [True])
diff --git a/yardstick/ssh.py b/yardstick/ssh.py
index 8b71fe606..e0e2f83ee 100644
--- a/yardstick/ssh.py
+++ b/yardstick/ssh.py
@@ -140,10 +140,12 @@ class SSH(object):
self._client = False
def run(self, cmd, stdin=None, stdout=None, stderr=None,
- raise_on_error=True, timeout=3600):
+ raise_on_error=True, timeout=3600,
+ keep_stdin_open=False):
"""Execute specified command on the server.
:param cmd: Command to be executed.
+ :type cmd: str
:param stdin: Open file or string to pass to stdin.
:param stdout: Open file to connect to stdout.
:param stderr: Open file to connect to stderr.
@@ -151,6 +153,8 @@ class SSH(object):
then exception will be raized if non-zero code.
:param timeout: Timeout in seconds for command execution.
Default 1 hour. No timeout if set to 0.
+ :param keep_stdin_open: don't close stdin on empty reads
+ :type keep_stdin_open: bool
"""
client = self._get_client()
@@ -160,10 +164,12 @@ class SSH(object):
return self._run(client, cmd, stdin=stdin, stdout=stdout,
stderr=stderr, raise_on_error=raise_on_error,
- timeout=timeout)
+ timeout=timeout,
+ keep_stdin_open=keep_stdin_open)
def _run(self, client, cmd, stdin=None, stdout=None, stderr=None,
- raise_on_error=True, timeout=3600):
+ raise_on_error=True, timeout=3600,
+ keep_stdin_open=False):
transport = client.get_transport()
session = transport.open_session()
@@ -203,13 +209,15 @@ class SSH(object):
if not data_to_send:
data_to_send = stdin.read(4096)
if not data_to_send:
- stdin.close()
- session.shutdown_write()
- writes = []
- continue
- sent_bytes = session.send(data_to_send)
- # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
- data_to_send = data_to_send[sent_bytes:]
+ # we may need to keep stdin open
+ if not keep_stdin_open:
+ stdin.close()
+ session.shutdown_write()
+ writes = []
+ if data_to_send:
+ sent_bytes = session.send(data_to_send)
+ # LOG.debug("sent: %s" % data_to_send[:sent_bytes])
+ data_to_send = data_to_send[sent_bytes:]
if session.exit_status_ready():
break