# -*- coding: utf-8 -*-
"""Classes used in loon package"""
import sys
import os
import json
import socket
import glob
import re
import io
from getpass import getpass
from subprocess import run, PIPE
from datetime import datetime
from shutil import copyfile
from ssh2.session import Session
# Project-local imports. When the module is executed directly there is no
# package context (``__package__`` is '' or None), so the sibling modules are
# imported by bare name; that path is only meant for quick manual testing.
if not __package__:
    from __init__ import __host_file__
    from utils import (create_parentdir, isfile, isdir, pretty_table,
                       get_filelist, read_csv)
else:
    from loon import __host_file__
    from loon.utils import (create_parentdir, isfile, isdir, pretty_table,
                            get_filelist, read_csv)
# Symlink-resolved path of this module, the directory that contains it, and
# the bundled ``data`` directory that sits next to it.
this_file = os.path.realpath(__file__)
this_dir = os.path.split(this_file)[0]
data_dir = os.path.join(this_dir, 'data')
class Host:
    """Representation of the remote hosts known to loon.

    Every host is stored as a four-item list
    ``[alias, username, address, port]``.

    Attributes:
        hostfile: path of the JSON file the host lists are persisted in
        active_host: the currently selected host list, or ``[]`` when there
            is none
        available_hosts: list of all registered host lists
        session: SSH session, set by `connect`
        channel: SSH channel, set by `connect` when `open_channel` is `True`
    """

    def __init__(self, hostfile=__host_file__):
        self.hostfile = hostfile
        self.load_hosts()

    def load_hosts(self):
        """Load hosts from file

        A missing host file yields empty host lists. Duplicated entries in
        the available list are dropped and the cleaned list is written back
        immediately.
        """
        if not isfile(self.hostfile):
            self.active_host = []
            self.available_hosts = []
        else:
            with open(self.hostfile, 'r') as f:
                hosts = json.load(f)
            self.active_host = hosts['active']
            self.available_hosts = hosts['available']

        # The active host must be one flat list; a nested list means the
        # config file holds several "active" hosts.
        if any(isinstance(i, list) for i in self.active_host):
            print(
                "Error: more than one active host. Please check config file ~/.config/loon/host.json and modify or remove it if necessary."
            )

        # Order-preserving de-duplication (entries are lists, so a set
        # cannot be used).
        unique_hosts = []
        for entry in self.available_hosts:
            if entry not in unique_hosts:
                unique_hosts.append(entry)
        had_duplicates = len(unique_hosts) != len(self.available_hosts)
        self.available_hosts = unique_hosts
        if had_duplicates:
            # Save unique hosts immediately
            self.save_hosts()

    def save_hosts(self):
        """Save hosts to file"""
        hosts = {'active': self.active_host, 'available': self.available_hosts}
        if not isfile(self.hostfile):
            # Create parent dir if hostfile does not exist
            create_parentdir(self.hostfile)
        with open(self.hostfile, 'w') as f:
            json.dump(hosts, f)

    def _require_active_host(self):
        """Return the active host, exiting with status 1 when none is set."""
        if not self.active_host:
            print("=> No active host, please add one with add command!")
            sys.exit(1)
        return self.active_host

    def add(self, name, username, host, port=22, dry_run=False):
        """Add a remote host

        The first host ever added also becomes the active host.

        Args:
            name: host alias, a string
            username: user name on the remote host, a string
            host: host ip address, a string
            port: host ip port, an integer
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        info = [name, username, host, port]
        if dry_run:
            print("=> Running add", tuple(info[1:]))
            sys.exit(0)
        if info in self.available_hosts:
            print("=> Input host exists. Will not change.")
            return
        self.available_hosts.append(info)
        if not self.active_host:
            # Store a copy so the active host never aliases a list entry.
            self.active_host = info.copy()
        self.save_hosts()
        print("=> Added successfully!")

    def host_check(self, name, username, host, port=22):
        """Check if a host exists

        The host is looked up by alias when `name` is given, otherwise by
        the (username, host, port) triple. If several entries match, the
        last one wins. Exits with status 1 when nothing matches.

        Args:
            name: host alias, a string (or `None`)
            username: user name on the remote host, a string
            host: host ip address, a string
            port: host ip port, an integer

        Returns:
            a list representing the host (a copy of the stored entry)
        """
        # NOTE: the result must not be stored in `host`; doing so clobbered
        # the parameter and made lookups by address impossible.
        matched = []
        if name is not None:
            for entry in self.available_hosts:
                if entry[0] == name:
                    matched = entry.copy()
        else:
            wanted = [username, host, port]
            for entry in self.available_hosts:
                if entry[1:] == wanted:
                    matched = entry.copy()
        if not matched:
            print(
                "=> Host does not exist, please check input with list command!"
            )
            sys.exit(1)
        return matched

    def delete(self, name, username, host, port=22, dry_run=False):
        """Delete a remote host

        When the deleted host was the active one, the first remaining host
        (if any) becomes active.

        Args:
            name: host alias, a string
            username: user name on the remote host, a string
            host: host ip address, a string
            port: host ip port, an integer
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        if dry_run:
            print("Running delete", (username, host, port))
            sys.exit(0)
        host2del = self.host_check(name, username, host, port)
        print("=> Removing host from available list...")
        self.available_hosts.remove(host2del)
        if host2del == self.active_host:
            print("=> Removing active host...")
            if self.available_hosts:
                self.active_host = self.available_hosts[0].copy()
                print("=> Changing active host to %s" % self.active_host[0])
            else:
                self.active_host = []  # reset
                print("=> Reseting active host to []")
        self.save_hosts()

    def switch(self, name, username, host, port=22, dry_run=False):
        """Switch active host

        Args:
            name: host alias, a string
            username: user name on the remote host, a string
            host: host ip address, a string
            port: host ip port, an integer
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        if dry_run:
            print("Running switch",
                  (username, host, port) if username is not None else name)
            sys.exit(0)
        host2switch = self.host_check(name, username, host, port)
        self.active_host = host2switch
        self.save_hosts()
        # Report the stored alias: `name` is None when the host was selected
        # by username/address.
        print("=> %s activated." % host2switch[0])

    def rename(self, old, new, dry_run=False):
        """Rename host name

        Args:
            old: a string representing the old host name alias
            new: a string representing the new host name alias
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        if dry_run:
            print("Running rename", old, "to", new)
            sys.exit(0)
        host2rename = []
        for entry in self.available_hosts:
            if entry[0] == old:
                # Keep a snapshot with the old alias to compare with the
                # active host below.
                host2rename = entry.copy()
                entry[0] = new
        if not host2rename:
            print(
                "=> Host does not exist, please check input with list command!"
            )
            sys.exit(1)
        if host2rename == self.active_host:
            self.active_host[0] = new
        self.save_hosts()

    def list(self):
        """List all remote hosts

        The alias of the active host is shown inside angle brackets.
        """
        title = ['Alias', 'Username', 'IP address', 'Port']
        content = []
        for entry in self.available_hosts:
            # Decorate a copy of the row; the stored entry must keep its
            # real alias.
            row = entry.copy()
            if entry == self.active_host:
                row[0] = '<' + row[0] + '>'
            content.append(row)
        pretty_table(title, content)
        print("<active host>")

    def connect(self,
                privatekey_file="~/.ssh/id_rsa",
                passphrase='',
                open_channel=True):
        """Connect active host and open a session

        Public-key authentication is tried first; if it fails for any
        reason the user is prompted for a password.

        Args:
            privatekey_file: a string representing the path to the private key file
            passphrase: a string representing the password
            open_channel: if `True`, open the SSH channel

        Returns:
            None
        """
        active = self._require_active_host()
        username = active[1]
        privatekey_file = os.path.expanduser(privatekey_file)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((active[2], active[3]))
        except OSError:
            # Do not leak the descriptor when the host is unreachable.
            sock.close()
            raise
        s = Session()
        s.handshake(sock)
        try:
            # Try using private key file first
            s.userauth_publickey_fromfile(username, privatekey_file,
                                          passphrase)
        except Exception:
            # Use password to auth. `Exception` rather than a bare except so
            # that Ctrl-C / SystemExit are not turned into a password prompt.
            passwd = getpass(
                'No private key found.\nEnter your password for %s: ' %
                username)
            s.userauth_password(username, passwd)
        self.session = s
        if open_channel:
            self.channel = self.session.open_session()

    def cmd(self,
            commands,
            _logger=None,
            run_file=False,
            data_dir=None,
            remote_file=False,
            dir='/tmp',
            prog=None,
            dry_run=False):
        """Run command(s) in active remote host using channel session

        A new connection (with channel) is opened for every execution.

        Args:
            commands: a command string, or a list of script paths when
                `run_file` is `True`
            _logger: the logging logger (required when `run_file` is `True`)
            run_file: if `True`, run scripts instead of commands
            data_dir: a path (or list of paths) representing data directory
                to upload next to local scripts
            remote_file: if `True`, collect input from remote host instead of local machine
            dir: Remote directory for storing local scripts
            prog: a string representing the program to run the commands;
                when `None` the scripts are made executable and run directly
            dry_run: if `True`, dry run the code

        Returns:
            A string containing result information
        """
        if dry_run:
            print("Running", "files:" if run_file else "commands:", commands)
            sys.exit(0)

        if not run_file:
            self.connect()
            self.channel.execute(commands)
            return self.get_result()

        # Run scripts: `commands` is a list of script paths here
        _logger.info(commands)
        if remote_file:
            scripts = self._expand_remote_scripts(commands)
        else:
            scripts = self._stage_local_scripts(commands, data_dir, dir,
                                                _logger)
        commands = self._build_run_commands(scripts, prog)
        _logger.info(commands)
        self.connect()
        print("=> Getting results:")
        self.channel.execute(commands)
        return self.get_result()

    def _expand_remote_scripts(self, scripts):
        """Expand wildcard script paths by listing them on the remote host."""
        # Supported wildcards: *, ?, {}
        wildcard = re.compile(r'\*|\?|\{\}')
        if not any(wildcard.search(path) for path in scripts):
            return scripts
        self.connect()
        self.channel.execute(';'.join('ls ' + path for path in scripts))
        listed = self.get_result(print_info=False).split('\n')
        if '' in listed:
            listed.remove('')
        return listed

    def _stage_local_scripts(self, scripts, data_dir, remote_dir, _logger):
        """Upload local scripts (and data) and return their remote paths."""
        # 1) upload
        self.upload(scripts, remote_dir, _logger)
        if data_dir is not None:
            # `upload` joins its sources with spaces, so a bare string would
            # be split into single characters.
            if isinstance(data_dir, str):
                data_dir = [data_dir]
            self.upload(data_dir, remote_dir, _logger)

        # 2) get all file names
        if len(scripts) == 1 and isdir(scripts[0]):
            # A single directory: its files end up in a same-named
            # sub-directory of the remote target.
            folder = scripts[0]
            if folder.endswith('/'):
                folder_name = os.path.basename(os.path.dirname(folder))
            else:
                folder_name = os.path.basename(folder)
            remote_dir = os.path.join(remote_dir, folder_name)
            scripts = glob.glob(folder + '/*')

        filelist = []
        for fp in scripts:
            _logger.info("fp:%s" % fp)
            fs = glob.glob(fp)
            _logger.info("fs:%s" % fs)
            for f in fs:
                _logger.info("f:%s" % f)
                if isdir(f):
                    print(
                        "Warning: directory %s is detected, note anything in it will be ignored to execute."
                        % f)
                elif isfile(f):
                    filelist.append(f)
                else:
                    print('Error: file %s does not exist.' % f)
                    sys.exit(1)
        filelist = [os.path.basename(f) for f in filelist]
        _logger.info(filelist)

        # 3) remote paths of the scripts, to be run one by one
        return ['/'.join([remote_dir, f]) for f in filelist]

    @staticmethod
    def _build_run_commands(scripts, prog):
        """Build the ';'-separated shell command line that runs `scripts`."""
        if prog is None:
            # No interpreter given: make the scripts executable, then run
            # them directly.
            make_executable = ';'.join('chmod u+x ' + s for s in scripts)
            return make_executable + ';' + ';'.join(scripts)
        return ';'.join('{} '.format(prog) + s for s in scripts)

    def get_result(self, print_info=True):
        """Get result from executed channel

        If the remote side wrote anything to stderr, it is printed and the
        program exits with status 1.

        Args:
            print_info: if `True`, print information

        Returns:
            a string containing output from executed commands
        """
        import codecs

        size, errinfo = self.channel.read_stderr()
        if size > 0:
            print('An error is raised by remote host, please read the info:\n')
            print(errinfo.decode('utf-8', errors='replace'), end="")
            sys.exit(1)

        # Decode incrementally: a multi-byte UTF-8 character may be split
        # across two reads and would be dropped by per-chunk decoding.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
        datalist = []
        size, data = self.channel.read()
        # Here data is byte type
        while size > 0:
            text = decoder.decode(data)
            if print_info:
                print(text, sep='', end='')
            datalist.append(text)
            size, data = self.channel.read()
        datalist.append(decoder.decode(b'', final=True))
        # Return a string containing output from commands
        return "".join(datalist)

    def upload(self,
               source,
               destination,
               _logger,
               use_rsync=False,
               dry_run=False):
        """Upload files to active remote host.

        The transfer is delegated to the local `scp` (default) or `rsync`
        command, run through the shell so wildcards in `source` are
        expanded. Exits with the command's status code when it fails.

        Args:
            source: list of files (directories) in local machine
            destination: destination directory in remote host
            _logger: the logging logger
            use_rsync: if `True`, use rsync instead of scp
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        username, host, port = self._require_active_host()[1:]
        if dry_run:
            print("Running upload", ' '.join(source), "to", destination, "on",
                  tuple(self.active_host[1:]))
            sys.exit(0)
        # Make sure scp/rsync recognize destination as directory
        # Path must end with '/'
        if not destination.endswith('/'):
            destination = destination + '/'
        if use_rsync:
            if sys.platform == 'win32':
                print("--rsync is disabled in Windows, please don't use it.")
                sys.exit(0)
            cmds = "rsync -azP -e 'ssh -p {port}' {source} {username}@{host}:{destination}".format(
                port=port,
                source=' '.join(map(os.path.expanduser, source)),
                username=username,
                host=host,
                destination=destination)
        else:
            cmds = "scp -pr -P {port} {source} {username}@{host}:{destination}".format(
                port=port,
                source=' '.join(map(os.path.expanduser, source)),
                username=username,
                host=host,
                destination=destination)
        print("=> Starting upload...", end="\n\n")
        now = datetime.now()
        _logger.info("Running " + cmds)
        run_res = run(cmds, shell=True)
        _logger.info("Status code: " + str(run_res.returncode))
        if run_res.returncode != 0:
            print("Error: an error occurred, please check the info!")
            sys.exit(run_res.returncode)
        taken = datetime.now() - now
        print("\n=> Finished uploading in %ss" % taken.seconds)

    def download(self,
                 source,
                 destination,
                 _logger,
                 use_rsync=False,
                 dry_run=False):
        """Download files to local machine from active remote host.

        The transfer is delegated to the local `scp` (default) or `rsync`
        command. The local destination directory is created when missing.
        Exits with the command's status code when it fails.

        Args:
            source: list of files (directories) in remote host
            destination: destination directory in local machine
            _logger: the logging logger
            use_rsync: if `True`, use rsync instead of scp
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        username, host, port = self._require_active_host()[1:]
        if dry_run:
            print("Running download", ' '.join(source), "to", destination,
                  "from", tuple(self.active_host[1:]))
            sys.exit(0)
        if not isdir(os.path.expanduser(destination)):
            os.makedirs(os.path.expanduser(destination))
        # Make sure scp/rsync recognize destination as directory
        # Path must end with '/'
        if not destination.endswith('/'):
            destination = destination + '/'
        print("=> Starting downloading...", end="\n\n")
        now = datetime.now()
        if use_rsync:
            if sys.platform == 'win32':
                print("--rsync is disabled in Windows, please don't use it.")
                sys.exit(0)
            cmds = "rsync -azP -e 'ssh -p {port}' {username}@{host}:'{source}' {destination}".format(
                port=port,
                source=' '.join(source),
                username=username,
                host=host,
                destination=os.path.expanduser(destination))
        else:
            cmds = "scp -pr -P {port} {username}@{host}:'{source}' {destination}".format(
                port=port,
                source=' '.join(source),
                username=username,
                host=host,
                destination=os.path.expanduser(destination))
        _logger.info("Running " + cmds)
        run_res = run(cmds, shell=True)
        _logger.info("Status code: " + str(run_res.returncode))
        if run_res.returncode != 0:
            print("Error: an error occurred, please check the info!")
            sys.exit(run_res.returncode)
        taken = datetime.now() - now
        print("\n=> Finished downloading in %ss" % taken.seconds)
class PBS:
    """Representation of PBS task

    Holds the paths of the template/example files bundled in the package
    data directory and implements the PBS related commands.
    """

    def __init__(self):
        self.tmp_header = os.path.join(data_dir, "PBS_HEADER.txt")
        self.tmp_cmds = os.path.join(data_dir, "PBS_CMDS.txt")
        self.pbs_template = os.path.join(data_dir, "pbs-template.pbs")
        self.samplefile = os.path.join(data_dir, "samplefile.csv")
        self.mapfile = os.path.join(data_dir, "mapping.csv")

    def gen_template(self, input, output, dry_run=False):
        """Generate a PBS template

        Without `input`, the bundled header and commands files are
        concatenated; otherwise `input` is copied. The output is always
        written as UTF-8 with LF line endings.

        Args:
            input: a string representing the path to template file
            output: a string representing the path to output file;
                defaults to ``work.pbs`` in the current directory
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        if output is None:
            output = os.path.join(os.getcwd(), 'work.pbs')
        print("=> Generating %s" % output)
        if dry_run:
            sys.exit(0)
        if isfile(output):
            print("Warning: the output file exists, it will be overwritten.")

        if input is None:
            parts = [self.tmp_header, self.tmp_cmds]
        else:
            if not isfile(input):
                print("Error: cannot find the template file.")
                sys.exit(1)
            parts = [input]

        with io.open(output, 'w', encoding='utf-8', newline='\n') as f:
            for part in parts:
                with open(part, 'r') as src:
                    f.writelines(src)
        print("=> Done.")

    def gen_pbs(self,
                template,
                samplefile,
                mapfile,
                outdir,
                _logger,
                pbs_mode=True,
                dry_run=False):
        """Generate a batch of (script) files (PBS tasks) based on template and mapping file

        For every row of the sample file one output file, named after the
        row's first column, is written. Each mapfile row is
        ``label,column_index``: every occurrence of the label in the
        template is replaced with that column of the sample row.

        Args:
            template: a string representing the path to the template file
            samplefile: a string representing the path to the sample file
            mapfile: a string representing the path to the mapping file
            outdir: a string representing the path to output directory
            _logger: the logging logger
            pbs_mode: if `True`, use PBS mode (outputs get a ``.pbs`` suffix)
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        outdir_missing = not isdir(outdir)
        if outdir_missing:
            print("Directory %s does not exist, creating it" % outdir)

        # Nothing can be generated without all three inputs, so stop here
        # instead of failing later while opening them.
        missing_inputs = [
            path for path in (template, samplefile, mapfile)
            if not isfile(path)
        ]
        for path in missing_inputs:
            print("Error: file %s does not exist" % path)
        if missing_inputs:
            sys.exit(1)

        print("=====================")
        print("Output path : " + outdir)
        if pbs_mode:
            print("PBS Template: " + template)
        else:
            print("Template: " + template)
        print("Sample file : " + samplefile)
        print("Mapping file: " + mapfile)
        print("=====================")
        if dry_run:
            sys.exit(0)

        # Only touch the file system once we know this is not a dry run.
        if outdir_missing:
            os.makedirs(outdir)

        print("=> Reading %s ..." % samplefile)
        sample_data = read_csv(samplefile)
        print("=> Reading %s ..." % mapfile)
        map_data = read_csv(mapfile)

        # Check if input files are valid
        # The first sample column names the output files, so it must be unique.
        if len(sample_data) != len({row[0] for row in sample_data}):
            print("Error: the first column is not unique!")
            sys.exit(1)
        for row in map_data:
            if len(row) != 2:
                print("Error: only two columns are quired in mapfile!")
                sys.exit(1)
            try:
                int(row[1])
            except (TypeError, ValueError):
                print(
                    "Error: the second column must be (or can be transformed to) an integer!"
                )
                sys.exit(1)

        print("=> Reading %s ..." % template)
        with open(template, 'r') as f:
            temp_data = f.read()

        print("Generating...")
        for row in sample_data:
            if pbs_mode:
                pbsfile = os.path.join(outdir, row[0] + '.pbs')
            else:
                pbsfile = os.path.join(outdir, row[0])
            _logger.info("Generating %s" % pbsfile)
            content = temp_data
            for label, column in map_data:
                try:
                    value = row[int(column)]
                except IndexError:
                    # Best effort: report it and keep the label untouched.
                    print(
                        "Error: the second column out of range for label %s!" %
                        label)
                    continue
                _logger.info("Replacing %s with %s" % (label, value))
                content = content.replace(label, value)
            with io.open(pbsfile, 'w', encoding='utf-8', newline='\n') as f:
                f.write(content)
        print("Done.")

    def gen_pbs_example(self, outdir, _logger, dry_run=False):
        """Generate example files for pbsgen command to specified directory

        Copies the bundled PBS template, sample file and mapping file.

        Args:
            outdir: a string representing the output directory
            _logger: the logging logger
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        outdir_missing = not isdir(outdir)
        if outdir_missing:
            print("Directory %s does not exist, creating it" % outdir)
        pbs_template = os.path.join(outdir,
                                    os.path.basename(self.pbs_template))
        samplefile = os.path.join(outdir, os.path.basename(self.samplefile))
        mapfile = os.path.join(outdir, os.path.basename(self.mapfile))
        print("=====================")
        print("Output path : " + outdir)
        print("PBS Template: " + pbs_template)
        print("Sample file : " + samplefile)
        print("Mapping file: " + mapfile)
        print("=====================")
        if dry_run:
            sys.exit(0)
        # Only touch the file system once we know this is not a dry run.
        if outdir_missing:
            os.makedirs(outdir)
        copyfile(self.pbs_template, pbs_template)
        copyfile(self.samplefile, samplefile)
        copyfile(self.mapfile, mapfile)
        print("Done.")

    def sub(self, host, tasks, remote, workdir, _logger, dry_run=False):
        """Submit pbs tasks

        Args:
            host: a host object
            tasks: a list of PBS files, glob pattern is supported
            remote: if `True`, means that PBS task files are located at the active remote host
            workdir: a directory representing the working directory;
                defaults to ``/tmp`` (remote) or the current directory (local)
            _logger: the logging logger
            dry_run: if `True`, dry run the code

        Returns:
            A list of files
        """
        print('NOTE: PBS file must be LF mode (Unix), not CRLF mode (Windows)')
        print('====================================================')
        filelist = []
        if remote:
            # Let the remote shell expand the patterns. NOTE: this listing
            # is needed to build the command, so it runs even for a dry run.
            tasks = ' '.join(tasks)
            host.connect()
            _logger.info('ls -p ' + tasks)
            host.channel.execute('ls -p ' + tasks)
            listing = host.get_result(print_info=False).split('\n')
            # Drop blank lines, directories ("name/", marked by `ls -p`) and
            # the "dir:" headers `ls` prints when listing directories.
            filelist = [
                f for f in listing
                if f not in ('', ' ')
                and not (len(f) > 1 and f.endswith(('/', ':')))
            ]
            _logger.info(filelist)
            if workdir is None:
                workdir = '/tmp'
            cmds = 'cd {}; for i in {}; do qsub $i; done'.format(
                workdir, ' '.join(filelist))
            if dry_run:
                print(cmds)
                sys.exit(0)
            _logger.info(cmds)
            host.cmd(cmds, _logger=_logger)
        else:
            if workdir is None:
                workdir = os.getcwd()
            for fp in tasks:
                for f in glob.glob(fp):
                    if isdir(f):
                        print(
                            "Warning: directory %s is detected, note anything in it will be ignored to execute."
                            % f)
                    elif isfile(f):
                        filelist.append(f)
                        cmds = 'cd ' + workdir + ';qsub ' + f
                        if dry_run:
                            print(cmds)
                        else:
                            _logger.info(cmds)
                            run(cmds, shell=True)
                    else:
                        print('Error: file %s does not exist.' % f)
                        sys.exit(1)
        return filelist

    def deploy(self,
               host,
               source,
               destination,
               _logger,
               use_rsync=False,
               dry_run=False):
        """Deploy target directory on the active remote host

        Upload the target directory and then submit all *.pbs files found
        directly under `destination`.

        Args:
            host: a host object
            source: a string representing the directory (contains .pbs files) to upload
            destination: a string representing the path on remote host
                (defaults to ``/tmp``)
            _logger: the logging logger
            use_rsync: if `True`, use rsync instead of scp
            dry_run: if `True`, dry run the code

        Returns:
            None
        """
        if destination is None:
            destination = '/tmp'
        if dry_run:
            print("Running deploy", source, "to", destination, "on",
                  tuple(host.active_host[1:]))
            sys.exit(0)
        if not isdir(source):
            print("Error: directory %s does not exist" % source)
            sys.exit(1)
        host.upload([source], destination, _logger, use_rsync=use_rsync)
        # NOTE(review): `scp -r dir dest/` normally creates `dest/dir/`, so
        # the uploaded *.pbs files may live one level below `destination`
        # -- confirm that this glob finds them.
        self.sub(host, [destination + '/*.pbs'], True, destination, _logger)

    def check(self, host, job_id, dry_run=False):
        """Check PBS task status

        Runs ``qstat`` (optionally for a single job) on the active host.

        Args:
            host: a host object
            job_id: a string the job id, or `None` for all jobs
            dry_run: if `True`, dry run the code

        Returns:
            Job status
        """
        if dry_run:
            if job_id is None:
                print("Running qstat on", tuple(host.active_host[1:]))
            else:
                print("Running qstat", job_id, "on",
                      tuple(host.active_host[1:]))
            sys.exit(0)
        command = 'qstat' if job_id is None else 'qstat ' + job_id
        return host.cmd(command)
if __name__ == "__main__":
    # Manual check: show where the module and its bundled data are located.
    for _location in (this_dir, data_dir):
        print(_location)