view src/org/tmatesoft/hg/internal/remote/SshConnector.java @ 708:4ffc17c0b534

Merge: tests for resolver and complex scenario. Enable commit for merged revisions. Reuse file revisions if nothing changed
author Artem Tikhomirov <tikhomirov.artem@gmail.com>
date Tue, 20 Aug 2013 18:41:34 +0200
parents a483b2b68a2e
children
line wrap: on
line source
/*
 * Copyright (c) 2013 TMate Software Ltd
 *  
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * For information on how to redistribute this software under
 * the terms of a license other than GNU General Public License
 * contact TMate Software at support@hg4j.com
 */
package org.tmatesoft.hg.internal.remote;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.SequenceInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.tmatesoft.hg.auth.HgAuthFailedException;
import org.tmatesoft.hg.auth.HgAuthenticator;
import org.tmatesoft.hg.core.HgRemoteConnectionException;
import org.tmatesoft.hg.core.Nodeid;
import org.tmatesoft.hg.core.SessionContext;
import org.tmatesoft.hg.repo.HgRemoteRepository.Range;
import org.tmatesoft.hg.repo.HgRemoteRepository.RemoteDescriptor;
import org.tmatesoft.hg.repo.HgRuntimeException;
import org.tmatesoft.hg.util.LogFacility.Severity;

import com.trilead.ssh2.Connection;
import com.trilead.ssh2.Session;
import com.trilead.ssh2.StreamGobbler;

/**
 * Remote repository via SSH
 * 
 * @author Artem Tikhomirov
 * @author TMate Software Ltd.
 */
public class SshConnector extends ConnectorBase {
	private RemoteDescriptor rd;
	private SessionContext sessionCtx;
	private Connection conn;
	private Session session;
	private int sessionUse;

	private StreamGobbler remoteErr, remoteOut;
	private OutputStream remoteIn;
	
	public void init(RemoteDescriptor remote, SessionContext sessionContext, Object globalConfig) throws HgRuntimeException {
		rd = remote;
		sessionCtx = sessionContext;
		setURI(remote.getURI());
	}
	
	public void connect() throws HgAuthFailedException, HgRemoteConnectionException, HgRuntimeException {
		try {
			conn = new Connection(uri.getHost(), uri.getPort() == -1 ? 22 : uri.getPort());
			conn.connect();
			authenticateClient();
		} catch (IOException ex) {
			throw new HgRemoteConnectionException("Failed to establish connection").setServerInfo(getServerLocation());
		}
	}

	private void authenticateClient() throws HgAuthFailedException {
		SshAuthMethod m = new SshAuthMethod(conn);
		if (uri.getUserInfo() != null) {
			try {
				m.tryWithUserInfo(uri.getUserInfo());
				return;
			} catch (HgAuthFailedException ex) {
				// FALL-THROUGH to try with Authenticator
			}
		}
		HgAuthenticator auth = sessionCtx.getAuthenticator(rd);
		auth.authenticate(rd, m);
	}
	
	public void disconnect() throws HgRemoteConnectionException {
		if (session != null) {
			doSessionClose();
		}
		if (conn != null) {
			conn.close();
			conn = null;
		}
	}
	
	public void sessionBegin() throws HgRemoteConnectionException {
		if (sessionUse > 0) {
			assert session != null;
			sessionUse++;
			return;
		}
		try {
			session = conn.openSession();
			final String path = uri.getPath();
			session.execCommand(String.format("hg -R %s serve --stdio", path.charAt(0) == '/' ? path.substring(1) : path));
			remoteErr = new StreamGobbler(session.getStderr());
			remoteOut = new StreamGobbler(session.getStdout());
			remoteIn = session.getStdin();
			sessionUse = 1;
		} catch (IOException ex) {
			throw new HgRemoteConnectionException("Failed to create ssh session", ex);
		}
	}
	
	public void sessionEnd() throws HgRemoteConnectionException {
		assert sessionUse > 0;
		assert session != null;
		if (sessionUse > 1) {
			sessionUse--;
			return;
		}
		doSessionClose();
	}

	public String getCapabilities() throws HgRemoteConnectionException {
		try {
			consume(remoteOut);
			consume(remoteErr);
			remoteIn.write(CMD_HELLO.getBytes());
			remoteIn.write('\n');
			remoteIn.write(CMD_CAPABILITIES.getBytes()); // see http connector for details
			remoteIn.write('\n');
			remoteIn.write(CMD_HEADS.getBytes());
			remoteIn.write('\n');
			checkError();
			int responseLen = readResponseLength();
			checkError();
			FilterStream s = new FilterStream(remoteOut, responseLen);
			BufferedReader r = new BufferedReader(new InputStreamReader(s));
			String line;
			while ((line = r.readLine()) != null) {
				if (line.startsWith(CMD_CAPABILITIES) && line.length() > (CMD_CAPABILITIES.length()+1)) {
					line = line.substring(CMD_CAPABILITIES.length());
					if (line.charAt(0) == ':') {
						return line.substring(CMD_CAPABILITIES.length() + 1);
					}
				}
			}
			r.close();
			consume(remoteOut);
			checkError();
			return new String();
		} catch (IOException ex) {
			throw new HgRemoteConnectionException("Failed to initiate dialog with server", ex).setRemoteCommand(CMD_HELLO).setServerInfo(getServerLocation());
		}
	}

	public InputStream heads() throws HgRemoteConnectionException {
		return executeCommand("heads", Collections.<Parameter>emptyList(), true);
	}
	
	public InputStream between(Collection<Range> ranges) throws HgRemoteConnectionException {
		StringBuilder sb = new StringBuilder(ranges.size() * 82);
		for (Range r : ranges) {
			r.append(sb).append(' ');
		}
		if (!ranges.isEmpty()) {
			sb.setLength(sb.length() - 1);
		}
		return executeCommand("between", Collections.singletonList(new Parameter("pairs", sb.toString())), true);
	}
	
	public InputStream branches(List<Nodeid> nodes) throws HgRemoteConnectionException {
		String l = join(nodes, ' ');
		return executeCommand("branches", Collections.singletonList(new Parameter("nodes", l)), true);
	}
	
	public InputStream changegroup(List<Nodeid> roots) throws HgRemoteConnectionException, HgRuntimeException {
		String l = join(roots, ' ');
		InputStream cg = executeCommand("changegroup", Collections.singletonList(new Parameter("roots", l)), false);
		InputStream prefix = new ByteArrayInputStream("HG10UN".getBytes());
		return new SequenceInputStream(prefix, cg);
	}

	public OutputStream unbundle(long outputLen, List<Nodeid> remoteHeads) throws HgRemoteConnectionException, HgRuntimeException {
		String l = join(remoteHeads, ' ');
		try {
			consume(remoteOut);
			consume(remoteErr);
			remoteIn.write(CMD_UNBUNDLE.getBytes());
			remoteIn.write('\n');
			writeParameters(Collections.singletonList(new Parameter("heads", l)));
			checkError();
			return new FilterOutputStream(remoteIn) {
				@Override
				public void close() throws IOException {
					out.flush();
					@SuppressWarnings("unused")
					int responseLen = readResponseLength();
					checkError();
					// XXX perhaps, need to return responseLen to caller? 
				}
			};
		} catch (IOException ex) {
			throw new HgRemoteConnectionException("Communication failure", ex).setRemoteCommand(CMD_UNBUNDLE).setServerInfo(getServerLocation());
		}
	}

	public InputStream pushkey(String opName, String namespace, String key, String oldValue, String newValue) throws HgRemoteConnectionException, HgRuntimeException {
		ArrayList<Parameter> p = new ArrayList<Parameter>();
		p.add(new Parameter("namespace", namespace));
		p.add(new Parameter("key", key));
		p.add(new Parameter("old", oldValue));
		p.add(new Parameter("new", newValue));
		return executeCommand("pushkey", p, true);
	}
	
	public InputStream listkeys(String namespace, String actionName) throws HgRemoteConnectionException, HgRuntimeException {
		return executeCommand("listkeys", Collections.singletonList(new Parameter("namespace", namespace)), true);
	}
	
	private InputStream executeCommand(String cmd, List<Parameter> parameters, boolean expectResponseLength) throws HgRemoteConnectionException {
		try {
			consume(remoteOut);
			consume(remoteErr);
			remoteIn.write(cmd.getBytes());
			remoteIn.write('\n');
			writeParameters(parameters);
			checkError();
			if (expectResponseLength) {
				int responseLen = readResponseLength();
				checkError();
				return new FilterStream(remoteOut, responseLen);
			} else {
				return new FilterStream(remoteOut, Integer.MAX_VALUE);
			}
		} catch (IOException ex) {
			throw new HgRemoteConnectionException("Communication failure", ex).setRemoteCommand(cmd).setServerInfo(getServerLocation());
		}
	}
	
	private void writeParameters(List<Parameter> parameters) throws IOException {
		for (Parameter p : parameters) {
			remoteIn.write(p.name().getBytes());
			remoteIn.write(' ');
			remoteIn.write(String.valueOf(p.size()).getBytes());
			remoteIn.write('\n');
			remoteIn.write(p.data());
			remoteIn.write('\n');
		}
	}

	private void consume(InputStream is) throws IOException {
		while (is.available() > 0) {
			is.read();
		}
	}

	private void checkError() throws IOException {
		if (remoteErr.available() > 0) {
			StringBuilder sb = new StringBuilder();
			int c;
			while ((c = remoteErr.read()) != -1) {
				sb.append((char)c);
			}
			throw new IOException(sb.toString());
		}
	}
	
	private int readResponseLength() throws IOException {
		int c;
		StringBuilder sb = new StringBuilder();
		while ((c = remoteOut.read()) != -1) {
			if (c == '\n') {
				break;
			}
			sb.append((char) c);
		}
		if (c == -1) {
			throw new EOFException();
		}
		try {
			return Integer.parseInt(sb.toString());
		} catch (NumberFormatException ex) {
			throw new IOException(String.format("Expected response length instead of %s", sb));
		}
	}


	private void doSessionClose() {
		if (session != null) {
			closeQuietly(remoteErr);
			closeQuietly(remoteOut);
			remoteErr = remoteOut = null;
			closeQuietly(remoteIn);
			remoteIn = null;
			session.close();
			session = null;
		}
		sessionUse = 0;
	}

	private void closeQuietly(Closeable c) {
		try {
			if (c != null) {
				c.close();
			}
		} catch (IOException ex) {
			sessionCtx.getLog().dump(getClass(), Severity.Warn, ex, null);
		}
	}

	private static String join(List<Nodeid> values, char sep) {
		StringBuilder sb = new StringBuilder(values.size() * 41);
		for (Nodeid n : values) {
			sb.append(n.toString());
			sb.append(sep);
		}
		if (!values.isEmpty()) {
			// strip last space 
			sb.setLength(sb.length() - 1);
		}
		return sb.toString();
	}

	private static final class Parameter {
		private final String name;
		private final byte[] data;

		public Parameter(String paramName, String paramValue) {
			assert paramName != null;
			assert paramValue != null;
			name = paramName;
			data = paramValue.getBytes();
		}

		public String name() {
			return name;
		}
		public int size() {
			return data.length;
		}
		public byte[] data() {
			return data;
		}
	}

	private static final class FilterStream extends FilterInputStream {
		private int length;

		public FilterStream(InputStream is, int initialLength) {
			super(is);
			length = initialLength;
		}
		
		@Override
		public int available() throws IOException {
			return Math.min(super.available(), length);
		}
		@Override
		public int read() throws IOException {
			if (length == 0) {
				return -1;
			}
			int r = super.read();
			if (r >= 0) {
				length--;
			}
			return r;
		}
		@Override
		public int read(byte[] b, int off, int len) throws IOException {
			if (length == 0) {
				return -1;
			}
			int r = super.read(b, off, Math.min(len, length));
			if (r >= 0) {
				assert r <= length;
				length -= r;
			}
			return r;
		}
		@Override
		public void close() throws IOException {
			length = 0;
			// INTENTIONALLY DOES NOT CLOSE THE STREAM
		}
	}
}