Repository

Take a URL list instead of a hosts list, and add a proper URL parser

Parent commits: 6ca417763c62531b1d37eea6f35f5e1537094cda
Children commits :

By Laurent Defert on 2013-06-01 10:53:09
Take a URL list instead of a hosts list, and add a proper URL parser

Difference with parent commit 6ca417763c62531b1d37eea6f35f5e1537094cda
Files modified:
httpget.py
--- 
+++ 
@@ -8,7 +8,6 @@
 import threading
 import time
 from  urllib import unquote
-import urlparse
 
 # TODO: Remove hard-coded configuration file, it hould be generated based on
 # the command argument.
@@ -16,7 +15,7 @@
 DEFAULT_CONF = '''# httpget default configuration file
 [client]
 # Hosts list
-hosts_list = 127.0.0.1:8080
+urls_list = 127.0.0.1:8080
 
 # Concurrent connections
 # nconn = 1
@@ -170,40 +169,96 @@
             print e
         self.lock.release()
 
-class HTTPGet(threading.Thread):
-    def __init__(self, hdd_file, host, raw_url, chunk_no, chunk_count, filesize, progress_bar):
-        name = raw_url
+class HTTPConnection(object):
+    def parse_url(self, url):
+        self.url = url
+        self.scheme = 'http'
+        self.host = None
+        self.port = 80
+        self.user = ''
+        self.password = ''
+        self.path = ''
+        self.filename = 'index.html'
+
+        self.scheme = 'http'
+        if '://' in url:
+            self.scheme, url = url.split('://', 1)
+            if self.scheme == 'https':
+                self.port = 443
+
+        url = url.replace('//', '/').strip('/')
+
+        if not '/' in url:
+            self.host = url
+            url = ''
+        else:
+            self.host, url = url.split('/', 1)
+
+        if '@' in self.host:
+            self.user, self.host = self.host.split('@', 1)
+
+        if ':' in self.user:
+            self.user, self.password = self.user.split(':', 1)
+
+        if ':' in self.host:
+            self.host, self.port = self.host.split(':', 1)
+            self.port = int(self.port)
+
+        self.path = '/' + url
+
+        if url != '':
+            self.filename = unquote(self.path)
+
+        if '/' in self.filename:
+            self.filename = self.filename.rsplit('/', 1)[1]
+
+    def get_connection(self):
+        if self.scheme == 'http':
+            return httplib.HTTPConnection(self.host, self.port)
+        elif self.scheme == 'https':
+            return httplib.HTTPSConnection(self.host, self.port)
+        else:
+             raise Exception('Unknown protocol: %s' % self.scheme)
+
+    def __str__(self):
+        return '%s -> %s://%s:%i%s' % (self.url, self.scheme, self.host, self.port, self.path)
+
+    def test(self):
+        print HTTPConnection().parse_url('google.fr')
+        print HTTPConnection().parse_url('google.fr/test')
+        print HTTPConnection().parse_url('google.fr/test/plop')
+        print HTTPConnection().parse_url('google.fr:81')
+        print HTTPConnection().parse_url('google.fr:81/')
+        print HTTPConnection().parse_url('google.fr:81/test').path
+        print HTTPConnection().parse_url('google.fr:81/test/plop')
+        print HTTPConnection().parse_url('https://google.fr')
+        print HTTPConnection().parse_url('https://google.fr:81')
+        print HTTPConnection().parse_url('https://google.fr/')
+        print HTTPConnection().parse_url('https://google.fr:81/')
+        print HTTPConnection().parse_url('https://google.fr:81/test')
+        print HTTPConnection().parse_url('https://google.fr:81/test/plop')
+        print HTTPConnection().parse_url('https://google.fr:81/test/plop/')
+        exit(1)
+
+class HTTPGet(threading.Thread, HTTPConnection):
+    def __init__(self, hdd_file, url, chunk_no, chunk_count, filesize, progress_bar):
+        self.parse_url(url)
+        name = url
         if chunk_no is not None:
-            name += '.part.%i' % chunk_no
+            name += ' %i' % chunk_no
         threading.Thread.__init__(self, name=name)
-        self.raw_url = raw_url
-        self.url = urlparse.urlparse(raw_url)
-        self.host = host
+
         self.chunk_no = chunk_no
         self.chunk_count = chunk_count
         self.hdd_file = hdd_file
         self.progress_bar = progress_bar
         self.filesize = filesize
-
-        self.port = '80'
-        if self.url.scheme == 'https':
-            self.port = '443'
-        if ':' in host:
-            self.host, self.port = host.rsplit(':', 1)
         self.http_status = 0
 
     def run(self):
         try:
-            if self.url.scheme == 'http':
-                get = httplib.HTTPConnection(self.host, self.port)
-            elif self.url.scheme == 'https':
-                get = httplib.HTTPSConnection(self.host, self.port)
-            else:
-                 raise GetException('Unknown protocol: %s' % self.proto)
             headers = {'User-Agent:': CLIENT_AGENT}
-            path = self.url.path
-            if path == '':
-                path = '/'
+
             if not self.chunk_no is None:
                 start = self.chunk_no * IO_SIZE
                 end = start + (self.chunk_count * IO_SIZE) - 1
@@ -215,7 +270,9 @@
                 content_range = 'bytes=%i-%i' % (start, end)
                 content_range = {'Range': content_range}
                 headers.update(content_range)
-            get.request('GET', path, headers=headers)
+
+            get = self.get_connection()
+            get.request('GET', self.path, headers=headers)
             get = get.getresponse()
             if not get.status in (200, 206):
                 raise GetException('HTTP %s %s' % (get.status, httplib.responses[get.status]))
@@ -246,89 +303,70 @@
             return
         except GetException, e:
             self.progress_bar.EraseTermLine()
-            print "[%s] %s://%s:%s%s" % (str(e), self.url.scheme, self.host, self.port, self.url.path)
+            print "[%s] %s" % (str(e), self.url)
         except IOError, e:
             self.progress_bar.EraseTermLine()
-            print "[%s] %s://%s:%s%s" % (str(e.strerror), self.url.scheme, self.host, self.port, self.url.path)
+            print "[%s] %s" % (str(e.strerror), self.url)
         except Exception, e:
             from traceback import print_exc
             self.progress_bar.EraseTermLine()
-            print "[%s] %s://%s:%s%s %s" % (e.__class__.__name__, self.url.scheme, self.host, self.port, self.url.path, str(e))
+            print "[%s] %s %s" % (e.__class__.__name__, self.url, str(e))
             print_exc()
         self.http_status = 0
 
-class Download(object):
-    def __init__(self, raw_url):
-        self.raw_url = raw_url
-        self.url = urlparse.urlparse(raw_url)
-        self.hosts = args.hosts_list.split(',')
-        self.filename = 'index.html'
-        if '/' in self.url.path:
-            self.filename = unquote(self.url.path).rsplit('/', 1)[1]
-
-    def get_connection(self):
-        if self.url.scheme == 'http':
-            return httplib.HTTPConnection(self.url.hostname, self.url.port)
-        elif self.url.scheme == 'https':
-            return httplib.HTTPSConnection(self.url.hostname, self.url.port)
-        else:
-             raise Exception('Unknown protocol:%s' % self.url.scheme)
+class Download(HTTPConnection):
+    def __init__(self, url):
+        self.parse_url(url)
 
     def download(self):
         progress_bar = ProgressBar(threading.Lock())
         headers = {'User-Agent:': CLIENT_AGENT}
         head = self.get_connection()
-        head.request('HEAD', self.raw_url, headers=headers)
+        head.request('HEAD', self.path, headers=headers)
         head = head.getresponse()
 
         if head.status in (301, 302, 307):  # Temporary redirect
-            self.raw_url = head.msg.getheader('location')
-            self.url = urlparse.urlparse(self.raw_url)
-            print '[HTTP %i] Redirected to %s' % (head.status, self.raw_url)
+            url = head.msg.getheader('location')
+            self.parse_url(url)
+            print '[HTTP %i] Redirected to %s' % (head.status, self.url)
             head = self.get_connection()
-            head.request('HEAD', self.raw_url, headers=headers)
+            head.request('HEAD', self.url, headers=headers)
             head = head.getresponse()
 
         if head.status != 200:  # Ok
-            print 'Invalid status code %s %s' % (self.raw_url, head.status)
+            print 'Invalid status code %s %s' % (self.url, head.status)
             return
 
         length = head.getheader('Content-Length', None)
         if length is not None:
             length = int(length)
 
-        filename = 'index.html'
-        if '/' in self.url.path and not self.url.path.endswith('/'):
-            filename = unquote(self.url.path.rsplit('/', 1)[1])
-        hdd_file = HDDFile(filename, length)
+        hdd_file = HDDFile(self.filename, length)
 
         progress_bar.SetMax(length)
 
-        port = '80'
-        if self.url.scheme == 'https':
-            port = 443
-        if self.url.port is not None:
-            port = self.url.port
+        urls = [self.url]
+
+        # When the filesize is known, try downloading from the server pool
         if length is not None and length > IO_SIZE:
-            hosts = ['%s:%s' % (self.url.hostname, port)] + self.hosts
+            for url in args.urls_list.split(','):
+                urls += [url + self.path]
+            chunk_no = 0
         else:
-            hosts = ['%s:%s' % (self.url.hostname, port)]
+            chunk_no = None
 
         # Remove any dupplicated hosts but preserve host order
-        _hosts = []
-        for host in hosts:
-            if host != '' and host not in _hosts:
-                _hosts.append(host)
-        hosts = _hosts
-
-        chunk_no = 0
-        if len(hosts) == 1:
-            chunk_no = None
+        _urls = []
+        for url in urls:
+            if url != '' and url not in _urls:
+                _urls.append(url)
+        urls = _urls
+
         processes = []
-        for host_no, host in enumerate(hosts):
+        for url_no, url in enumerate(urls):
             if len(processes) == args.nconn:
                break
-            get = HTTPGet(hdd_file, host, self.raw_url, chunk_no, 1, length, progress_bar)
+            get = HTTPGet(hdd_file, url, chunk_no, 1, length, progress_bar)
             get.start()
             processes.append(get)
 
@@ -350,7 +388,7 @@
                             if chk_count < 1:
                                 chk_count = 1
                             chunk_no += chk_count
-                        get = HTTPGet(hdd_file, '%s:%s' % (get.host, get.port), get.raw_url, chk_no, chk_count, length, progress_bar)
+                        get = HTTPGet(hdd_file, get.url, chk_no, chk_count, length, progress_bar)
                         get.start()
                         _processes.append(get)
                     else:
@@ -368,7 +406,7 @@
     parser = argparse.ArgumentParser(description='Serve a file over http.')
     parser.add_argument('--config', '-c', metavar='configuration file', type=str, nargs=1,
                        help='Configuration file', default=None)
-    parser.add_argument('--hosts_list', '-H', metavar='host1,host2', type=str, nargs=1,
+    parser.add_argument('--urls_list', '-H', metavar='http://host1/,http://host2/', type=str, nargs=1,
                        help='Configuration file', default='127.0.0.1')
     parser.add_argument('--nconn', '-n', metavar='host1,host2', type=str, nargs=1,
                        help='Concurrent connections', default=5)
@@ -405,11 +443,11 @@
                 kw = {option: value}
                 parser.set_defaults(**kw)
             else:
-                print >> sys.stderr, 'Invalid option "%s"' % option
+                print >> sys.stderr, 'Invalid option "%s" in configuration file %s' % (option, args.config)
                 sys.exit(1)
         args = parser.parse_args()
 
-    for single_opt in ['config', 'nconn', 'bw_limit', 'hosts_list']:
+    for single_opt in ['config', 'nconn', 'bw_limit', 'urls_list']:
         val = getattr(args, single_opt)
         if isinstance(val, list):
             val = val[0]