Skip to content

Commit 4f32dd9

Browse files
committed
Adding authenticated read / write
1 parent b3524d4 commit 4f32dd9

1 file changed

Lines changed: 53 additions & 27 deletions

File tree

stashcp.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import socket
1313
import random
1414
import shutil
15+
from urlparse import urlparse
1516

1617
try:
1718
from pkg_resources import resource_string
@@ -27,7 +28,8 @@
2728

2829
main_redirector = "root://redirector.osgstorage.org"
2930
stash_origin = "root://stash.osgconnect.net"
30-
writeback_host = "http://stash-xrd.osgconnect.net:1094"
31+
writeback_host = "http://redirector.osgstorage.org"
32+
#writeback_host = "http://stash-xrd.osgconnect.net:1094"
3133

3234
# Global variable for nearest cache
3335
nearest_cache = None
@@ -38,6 +40,9 @@
3840
# Global variable for the location of the caches.json file
3941
caches_json_location = None
4042

43+
# Global variable for the location of the token to use for reading / writing
44+
token_location = None
45+
4146
TIMEOUT = 300
4247
DIFF = TIMEOUT * 10
4348

@@ -49,27 +54,12 @@ def doWriteBack(source, destination):
4954
:param str destination: The location of the remote file, in stash:// format
5055
"""
5156
start1 = int(time.time()*1000)
52-
53-
# Get the scitoken content
54-
scitoken_file = None
55-
if '_CONDOR_CREDS' in os.environ:
56-
# First, look for the scitokens.use file
57-
# Format: _CONDOR_CREDS=/var/lib/condor/execute/dir_908/.condor_creds
58-
scitoken_file = os.path.join(os.environ['_CONDOR_CREDS'], 'scitokens.use')
59-
if not os.path.exists(scitoken_file):
60-
scitoken_file = None
61-
62-
if not scitoken_file and os.path.exists(".condor_creds/scitokens.use"):
63-
scitoken_file = ".condor_creds/scitokens.use"
64-
65-
if not scitoken_file:
57+
58+
scitoken_contents = getToken()
59+
if scitoken_contents is None:
6660
logging.error("Unable to find scitokens.use file")
6761
return 1
6862

69-
70-
with open(scitoken_file, 'r') as scitoken_obj:
71-
scitoken_contents = scitoken_obj.read().strip()
72-
7363
# Remove the stash:// at the beginning, don't need it
7464
destination = destination.replace("stash://", "")
7565

@@ -119,6 +109,26 @@ def doWriteBack(source, destination):
119109
es_send(payload)
120110
return curl_exit
121111

112+
def getToken():
113+
"""
114+
Get the token / scitoken from the environment in order to read / write
115+
"""
116+
# Get the scitoken content
117+
scitoken_file = None
118+
if token_location:
119+
scitoken_file = token_location
120+
121+
if 'TOKEN' in os.environ:
122+
scitoken_file = os.environ['TOKEN']
123+
124+
if not scitoken_file or not os.path.exists(scitoken_file):
125+
logging.info("Unable to find token file")
126+
return None
127+
128+
with open(scitoken_file, 'r') as scitoken_obj:
129+
scitoken_contents = scitoken_obj.read().strip()
130+
131+
return scitoken_contents
122132

123133
def doStashCpSingle(sourceFile, destination, methods, debug=False):
124134
"""
@@ -299,11 +309,10 @@ def download_http(source, destination, debug, payload):
299309
global nearest_cache
300310
global nearest_cache_list
301311

302-
if not nearest_cache:
303-
nearest_cache = get_best_stashcache()
304-
305312
logging.debug("Downloading with HTTP")
306313

314+
scitoken_contents = getToken()
315+
307316
if not nearest_cache:
308317
nearest_cache = get_best_stashcache()
309318

@@ -340,11 +349,17 @@ def download_http(source, destination, debug, payload):
340349
cache = cache.replace('root://', 'http://')
341350

342351
# Append port 8000, which is just a convention for now, not set in stone
343-
cache += ":8000"
352+
# Check if the cache already has a port attached to it
353+
parsed_url = urlparse(cache)
354+
if not parsed_url.port:
355+
cache += ":8000"
344356

345357
# Quote the source URL, which may have weird, dangerous characters
346358
quoted_source = urllib2.quote(source)
347-
curl_command = "curl %s -L --connect-timeout 30 --speed-limit 1024 %s --fail %s%s" % (output_mode, download_output, cache, quoted_source)
359+
if scitoken_contents:
360+
curl_command = "curl %s -L --connect-timeout 30 --speed-limit 1024 %s --fail -H \"Authorization: Bearer %s\" %s%s" % (output_mode, download_output, scitoken_contents, cache, quoted_source)
361+
else:
362+
curl_command = "curl %s -L --connect-timeout 30 --speed-limit 1024 %s --fail %s%s" % (output_mode, download_output, cache, quoted_source)
348363
logging.debug("About to run curl command: %s", curl_command)
349364
start = int(time.time()*1000)
350365
command_object = subprocess.Popen([curl_command], shell=True, cwd=dest_dir)
@@ -591,6 +606,7 @@ def main():
591606
global nearest_cache
592607
global nearest_cache_list
593608
global caches_json_location
609+
global token_location
594610

595611
usage = "usage: %prog [options] source destination"
596612
parser = optparse.OptionParser(usage)
@@ -601,6 +617,7 @@ def main():
601617
parser.add_option('-j', '--caches-json', dest='caches_json', help="The JSON file containing the list of caches",
602618
default=None)
603619
parser.add_option('--methods', dest='methods', help="Comma separated list of methods to try, in order. Default: cvmfs,xrootd,http", default="cvmfs,xrootd,http")
620+
parser.add_option('-t', '--token', dest='token', help="Token file to use for reading and/or writing")
604621
args,opts=parser.parse_args()
605622

606623
logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
@@ -613,7 +630,10 @@ def main():
613630
else:
614631
logger.setLevel(logging.WARNING)
615632

616-
caches_json_location = args.caches_json
633+
if 'CACHES_JSON' in os.environ:
634+
caches_json_location = os.environ['CACHES_JSON']
635+
else:
636+
caches_json_location = args.caches_json
617637
if args.closest:
618638
print get_best_stashcache()
619639
sys.exit(0)
@@ -625,9 +645,15 @@ def main():
625645
destination=opts[1]
626646

627647
# Check for manually entered cache to use
628-
if args.cache and len(args.cache) > 0:
648+
if 'NEAREST_CACHE' in os.environ:
649+
nearest_cache = os.environ['NEAREST_CACHE']
650+
nearest_cache_list = [nearest_cache]
651+
elif args.cache and len(args.cache) > 0:
629652
nearest_cache = args.cache
630-
nearest_cache_list = [ args.cache ]
653+
nearest_cache_list = [args.cache]
654+
655+
if args.token:
656+
token_location = args.token
631657

632658
# Convert the methods
633659
methods = args.methods.split(',')

0 commit comments

Comments
 (0)