changeset 878:f39882ce0e89

Do bitonic sort for arbitrary array sizes.
author Martin Geisler <mg@daimi.au.dk>
date Fri, 08 Aug 2008 10:27:24 +0200
parents f12a36276d56
children 67a0628c2f83
files apps/sort.py
diffstat 1 files changed, 8 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/apps/sort.py	Fri Aug 08 10:23:14 2008 +0200
+++ b/apps/sort.py	Fri Aug 08 10:27:24 2008 +0200
@@ -41,7 +41,7 @@
 # Give a player configuration file as a command line argument or run
 # the example with '--help' for help with the command line options.
 
-from math import log
+from math import log, floor
 from optparse import OptionParser
 from twisted.internet import reactor
 
@@ -56,7 +56,7 @@
 parser.add_option("--modulus",
                   help="lower limit for modulus (can be an expression)")
 parser.add_option("-s", "--size", type="int",
-                  help="array size (must be power of 2)")
+                  help="array size")
 parser.add_option("-m", "--max", type="int",
                   help="maximum size of array numbers")
 parser.set_defaults(modulus=2**65, size=8, max=100)
@@ -68,10 +68,6 @@
 if len(args) == 0:
     parser.error("you must specify a config file")
 
-log_s = log(options.size, 2)
-if int(log_s) != log_s:
-    parser.error("the array size must be a power of 2")
-
 Zp = GF(find_prime(options.modulus, blum=True))
 
 class Protocol:
@@ -113,17 +109,18 @@
         def bitonic_sort(low, n, ascending):
             if n > 1:
                 m = n // 2
-                bitonic_sort(low, m, ascending=True)
-                bitonic_sort(low + m, m, ascending=False)
+                bitonic_sort(low, m, ascending=not ascending)
+                bitonic_sort(low + m, n - m, ascending)
                 bitonic_merge(low, n, ascending)
 
         def bitonic_merge(low, n, ascending):
             if n > 1:
-                m = n // 2
-                for i in range(low, low + m):
+                # Choose m as the greatest power of 2 less than n.
+                m = 2**int(floor(log(n-1, 2)))
+                for i in range(low, low + n - m):
                     compare(i, i+m, ascending)
                 bitonic_merge(low, m, ascending)
-                bitonic_merge(low + m, m, ascending)
+                bitonic_merge(low + m, n - m, ascending)
 
         def compare(i, j, ascending):