Mercurial > jhg
view src/org/tmatesoft/hg/internal/remote/SshConnector.java @ 702:992fa84e7885
Reference ssh-dependant class by the fq name to avoid runtime dependency from the ssh library
author | Artem Tikhomirov <tikhomirov.artem@gmail.com> |
---|---|
date | Thu, 08 Aug 2013 21:32:22 +0200 |
parents | a483b2b68a2e |
children |
line wrap: on
line source
/* * Copyright (c) 2013 TMate Software Ltd * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; version 2 of the License. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * For information on how to redistribute this software under * the terms of a license other than GNU General Public License * contact TMate Software at support@hg4j.com */ package org.tmatesoft.hg.internal.remote; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.EOFException; import java.io.FilterInputStream; import java.io.FilterOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.SequenceInputStream; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import org.tmatesoft.hg.auth.HgAuthFailedException; import org.tmatesoft.hg.auth.HgAuthenticator; import org.tmatesoft.hg.core.HgRemoteConnectionException; import org.tmatesoft.hg.core.Nodeid; import org.tmatesoft.hg.core.SessionContext; import org.tmatesoft.hg.repo.HgRemoteRepository.Range; import org.tmatesoft.hg.repo.HgRemoteRepository.RemoteDescriptor; import org.tmatesoft.hg.repo.HgRuntimeException; import org.tmatesoft.hg.util.LogFacility.Severity; import com.trilead.ssh2.Connection; import com.trilead.ssh2.Session; import com.trilead.ssh2.StreamGobbler; /** * Remote repository via SSH * * @author Artem Tikhomirov * @author TMate Software Ltd. */ public class SshConnector extends ConnectorBase { private RemoteDescriptor rd; private SessionContext sessionCtx; private Connection conn; private Session session; private int sessionUse; private StreamGobbler remoteErr, remoteOut; private OutputStream remoteIn; public void init(RemoteDescriptor remote, SessionContext sessionContext, Object globalConfig) throws HgRuntimeException { rd = remote; sessionCtx = sessionContext; setURI(remote.getURI()); } public void connect() throws HgAuthFailedException, HgRemoteConnectionException, HgRuntimeException { try { conn = new Connection(uri.getHost(), uri.getPort() == -1 ? 22 : uri.getPort()); conn.connect(); authenticateClient(); } catch (IOException ex) { throw new HgRemoteConnectionException("Failed to establish connection").setServerInfo(getServerLocation()); } } private void authenticateClient() throws HgAuthFailedException { SshAuthMethod m = new SshAuthMethod(conn); if (uri.getUserInfo() != null) { try { m.tryWithUserInfo(uri.getUserInfo()); return; } catch (HgAuthFailedException ex) { // FALL-THROUGH to try with Authenticator } } HgAuthenticator auth = sessionCtx.getAuthenticator(rd); auth.authenticate(rd, m); } public void disconnect() throws HgRemoteConnectionException { if (session != null) { doSessionClose(); } if (conn != null) { conn.close(); conn = null; } } public void sessionBegin() throws HgRemoteConnectionException { if (sessionUse > 0) { assert session != null; sessionUse++; return; } try { session = conn.openSession(); final String path = uri.getPath(); session.execCommand(String.format("hg -R %s serve --stdio", path.charAt(0) == '/' ? path.substring(1) : path)); remoteErr = new StreamGobbler(session.getStderr()); remoteOut = new StreamGobbler(session.getStdout()); remoteIn = session.getStdin(); sessionUse = 1; } catch (IOException ex) { throw new HgRemoteConnectionException("Failed to create ssh session", ex); } } public void sessionEnd() throws HgRemoteConnectionException { assert sessionUse > 0; assert session != null; if (sessionUse > 1) { sessionUse--; return; } doSessionClose(); } public String getCapabilities() throws HgRemoteConnectionException { try { consume(remoteOut); consume(remoteErr); remoteIn.write(CMD_HELLO.getBytes()); remoteIn.write('\n'); remoteIn.write(CMD_CAPABILITIES.getBytes()); // see http connector for details remoteIn.write('\n'); remoteIn.write(CMD_HEADS.getBytes()); remoteIn.write('\n'); checkError(); int responseLen = readResponseLength(); checkError(); FilterStream s = new FilterStream(remoteOut, responseLen); BufferedReader r = new BufferedReader(new InputStreamReader(s)); String line; while ((line = r.readLine()) != null) { if (line.startsWith(CMD_CAPABILITIES) && line.length() > (CMD_CAPABILITIES.length()+1)) { line = line.substring(CMD_CAPABILITIES.length()); if (line.charAt(0) == ':') { return line.substring(CMD_CAPABILITIES.length() + 1); } } } r.close(); consume(remoteOut); checkError(); return new String(); } catch (IOException ex) { throw new HgRemoteConnectionException("Failed to initiate dialog with server", ex).setRemoteCommand(CMD_HELLO).setServerInfo(getServerLocation()); } } public InputStream heads() throws HgRemoteConnectionException { return executeCommand("heads", Collections.<Parameter>emptyList(), true); } public InputStream between(Collection<Range> ranges) throws HgRemoteConnectionException { StringBuilder sb = new StringBuilder(ranges.size() * 82); for (Range r : ranges) { r.append(sb).append(' '); } if (!ranges.isEmpty()) { sb.setLength(sb.length() - 1); } return executeCommand("between", Collections.singletonList(new Parameter("pairs", sb.toString())), true); } public InputStream branches(List<Nodeid> nodes) throws HgRemoteConnectionException { String l = join(nodes, ' '); return executeCommand("branches", Collections.singletonList(new Parameter("nodes", l)), true); } public InputStream changegroup(List<Nodeid> roots) throws HgRemoteConnectionException, HgRuntimeException { String l = join(roots, ' '); InputStream cg = executeCommand("changegroup", Collections.singletonList(new Parameter("roots", l)), false); InputStream prefix = new ByteArrayInputStream("HG10UN".getBytes()); return new SequenceInputStream(prefix, cg); } public OutputStream unbundle(long outputLen, List<Nodeid> remoteHeads) throws HgRemoteConnectionException, HgRuntimeException { String l = join(remoteHeads, ' '); try { consume(remoteOut); consume(remoteErr); remoteIn.write(CMD_UNBUNDLE.getBytes()); remoteIn.write('\n'); writeParameters(Collections.singletonList(new Parameter("heads", l))); checkError(); return new FilterOutputStream(remoteIn) { @Override public void close() throws IOException { out.flush(); @SuppressWarnings("unused") int responseLen = readResponseLength(); checkError(); // XXX perhaps, need to return responseLen to caller? } }; } catch (IOException ex) { throw new HgRemoteConnectionException("Communication failure", ex).setRemoteCommand(CMD_UNBUNDLE).setServerInfo(getServerLocation()); } } public InputStream pushkey(String opName, String namespace, String key, String oldValue, String newValue) throws HgRemoteConnectionException, HgRuntimeException { ArrayList<Parameter> p = new ArrayList<Parameter>(); p.add(new Parameter("namespace", namespace)); p.add(new Parameter("key", key)); p.add(new Parameter("old", oldValue)); p.add(new Parameter("new", newValue)); return executeCommand("pushkey", p, true); } public InputStream listkeys(String namespace, String actionName) throws HgRemoteConnectionException, HgRuntimeException { return executeCommand("listkeys", Collections.singletonList(new Parameter("namespace", namespace)), true); } private InputStream executeCommand(String cmd, List<Parameter> parameters, boolean expectResponseLength) throws HgRemoteConnectionException { try { consume(remoteOut); consume(remoteErr); remoteIn.write(cmd.getBytes()); remoteIn.write('\n'); writeParameters(parameters); checkError(); if (expectResponseLength) { int responseLen = readResponseLength(); checkError(); return new FilterStream(remoteOut, responseLen); } else { return new FilterStream(remoteOut, Integer.MAX_VALUE); } } catch (IOException ex) { throw new HgRemoteConnectionException("Communication failure", ex).setRemoteCommand(cmd).setServerInfo(getServerLocation()); } } private void writeParameters(List<Parameter> parameters) throws IOException { for (Parameter p : parameters) { remoteIn.write(p.name().getBytes()); remoteIn.write(' '); remoteIn.write(String.valueOf(p.size()).getBytes()); remoteIn.write('\n'); remoteIn.write(p.data()); remoteIn.write('\n'); } } private void consume(InputStream is) throws IOException { while (is.available() > 0) { is.read(); } } private void checkError() throws IOException { if (remoteErr.available() > 0) { StringBuilder sb = new StringBuilder(); int c; while ((c = remoteErr.read()) != -1) { sb.append((char)c); } throw new IOException(sb.toString()); } } private int readResponseLength() throws IOException { int c; StringBuilder sb = new StringBuilder(); while ((c = remoteOut.read()) != -1) { if (c == '\n') { break; } sb.append((char) c); } if (c == -1) { throw new EOFException(); } try { return Integer.parseInt(sb.toString()); } catch (NumberFormatException ex) { throw new IOException(String.format("Expected response length instead of %s", sb)); } } private void doSessionClose() { if (session != null) { closeQuietly(remoteErr); closeQuietly(remoteOut); remoteErr = remoteOut = null; closeQuietly(remoteIn); remoteIn = null; session.close(); session = null; } sessionUse = 0; } private void closeQuietly(Closeable c) { try { if (c != null) { c.close(); } } catch (IOException ex) { sessionCtx.getLog().dump(getClass(), Severity.Warn, ex, null); } } private static String join(List<Nodeid> values, char sep) { StringBuilder sb = new StringBuilder(values.size() * 41); for (Nodeid n : values) { sb.append(n.toString()); sb.append(sep); } if (!values.isEmpty()) { // strip last space sb.setLength(sb.length() - 1); } return sb.toString(); } private static final class Parameter { private final String name; private final byte[] data; public Parameter(String paramName, String paramValue) { assert paramName != null; assert paramValue != null; name = paramName; data = paramValue.getBytes(); } public String name() { return name; } public int size() { return data.length; } public byte[] data() { return data; } } private static final class FilterStream extends FilterInputStream { private int length; public FilterStream(InputStream is, int initialLength) { super(is); length = initialLength; } @Override public int available() throws IOException { return Math.min(super.available(), length); } @Override public int read() throws IOException { if (length == 0) { return -1; } int r = super.read(); if (r >= 0) { length--; } return r; } @Override public int read(byte[] b, int off, int len) throws IOException { if (length == 0) { return -1; } int r = super.read(b, off, Math.min(len, length)); if (r >= 0) { assert r <= length; length -= r; } return r; } @Override public void close() throws IOException { length = 0; // INTENTIONALLY DOES NOT CLOSE THE STREAM } } }