将递归转化成迭代的通用技术
从理论上讲,只要允许使用栈,所有的递归程序都可以转化成迭代。
但是并非所有递归都必须用栈,不用堆栈也可以转化成迭代的,大致有两类
- 尾递归:可以通过简单的变换,让递归作为最后一条语句,并且仅此一个递归调用。
1 2 3 4 5 6 7 8 9 10 11 |
// recursive int fac1(int n) { if (n <= 0) return 1; return n * fac1(n-1); } // iterative int fac2(int n) { int i = 1, y = 1; for (; i <= n; ++i) y *= i; return y; } |
- 自顶向下->自底向上:对程序的结构有深刻理解后,自底向上计算,比如 fibnacci 数列的递归->迭代转化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
// recursive, top-down int fib1(int n) { if (n <= 1) return 1; return fib1(n-1) + fib1(n-2); } // iterative, down-top int fib2(int n) { int f0 = 1, f1 = 1, i; for (i = 2; i <= n; ++i) { int f2 = f1 + f0; f0 = f1; f1 = f2; } return f1; } |
对于非尾递归,就必须使用堆栈。可以简单生硬地使用堆栈进行转化:把函数调用和返回的地方翻译成汇编代码,然后把对硬件 stack 的 push, pop 操作转化成对私有 stack 的 push, pop ,这其中需要特别注意的是对返回地址的 push/pop,对应的硬件指令一般是 call/ret。使用私有 stack 有两个好处:
- 可以省去公用局部变量,也就是在任何一次递归调用中都完全相同的函数参数,再加上从这些参数计算出来的局部变量。
- 如果需要得到当前的递归深度,可以从私有 stack 直接拿到,而用递归一般需要一个单独的 depth 变量,然后每次递归调用加 1。
我们把私有 stack 元素称为 Frame,那么 Frame 中必须包含以下信息:
- 返回地址(对应于每个递归调用的下一条语句的地址)
- 对每次递归调用都不同的参数
通过实际操作,我发现,有一类递归的 Frame 可以省去返回地址!所以,这里又分为两种情况:
- Frame 中可以省去返回地址的递归:仅有两个递归调用,并且其中有一个是尾递归。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
// here used a function 'partition', but don't implement it tempalte<class RandIter> void QuickSort1(RandIter beg, RandIter end) { if (end - beg <= 1) return; RandIter pos = partition(beg, end); QuickSort1(beg, pos); QuickSort1(pos + 1, end); } tempalte<class RandIter> void QuickSort2(RandIter beg, RandIter end) { std::stack<std::pair<RandIter> > stk; stk.push({beg, end}); while (!stk.empty()) { std::pair<RandIter, RandIter> ii = stk.top(); stk.pop(); if (ii.second - ii.first) > 1) { RandIter pos = partition(beg, end); stk.push({ii.first, pos}); stk.push({pos + 1, ii.second}); } } } |
- Frame 中必须包含返回地址的递归,这个比较复杂,所以我写了个完整的示例:
- 以MergeSort为例,因为 MergeSort 是个后序过程,两个递归调用中没有任何一个是尾递归
- MergeSort3 使用了 GCC 的 Label As Value 特性,只能在 GCC 兼容的编译器中使用
- 单纯对于这个实例来说,返回地址其实只有两种,返回地址为 0 的情况可以通过判断私有栈(varname=stk)是否为空,stk为空时等效于 retaddr == 0。如果要精益求精,一般情况下指针的最低位总是0,可以把这个标志保存在指针的最低位,当然,如此的话就无法对 sizeof(T)==1 的对象如 char 进行排序了。
-
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202#include <stdio.h>#include <string.h># if 1#include <stack>#include <vector>template<class T>class MyStack : public std::stack<T, std::vector<T> >{};#elsetemplate<class T>class MyStack {union {char* a;T* p;};int n, t;public:explicit MyStack(int n=128) {this->n = n;this->t = 0;a = new char[n*sizeof(T)];}~MyStack() {while (t > 0)pop();delete[] a;}void swap(MyStack<T>& y) {char* q = y.a; y.a = a; a = q;int z;z = y.n; y.n = n; n = z;z = y.t; y.t = t; t = z;}T& top() const {return p[t-1];}void pop() {--t;p[t].~T();}void push(const T& x) {x.print(); // debugp[t] = x;++t;}int size() const { return t; }bool empty() const { return 0 == t; }bool full() const { return n == t; }};#endiftemplate<class T>struct Frame {static T* base;T *beg, *tmp;int len;int retaddr;Frame(T* beg, T* tmp, int len, int retaddr): beg(beg), tmp(tmp), len(len), retaddr(retaddr){}void print() const { // for debugprintf("%4d %4d %d/n", int(beg-base), len, retaddr);}};template<class T> T* Frame<T>::base;#define TOP(field) stk.top().fieldtemplate<class T>bool issorted(const T* a, int n){for (int i = 1; i < n; ++i) {if (a[i-1] > a[i]) return false;}return true;}template<class T>void mymerge(const T* a, int la, const T* b, int lb, T* c) {int i = 0, j = 0, k = 0;for (; i < la && j < lb; ++k) {if (b[j] < a[i])c[k] = b[j], ++j;elsec[k] = a[i], ++i;}for (; i < la; ++i, ++k) c[k] = a[i];for (; j < lb; ++j, ++k) c[k] = b[j];}template<class T>void MergeSort1(T* beg, T* tmp, int len) {if (len > 1) {int mid = len / 2;MergeSort1(beg , tmp , mid);MergeSort1(beg+mid, tmp+mid, len-mid);mymerge(tmp, mid, tmp+mid, len-mid, beg);memcpy(tmp, beg, sizeof(T)*len);}else*tmp = *beg;}template<class T>void MergeSort2(T* beg0, T* tmp0, int len0) {int mid;int cnt = 0;Frame<T>::base = beg0;MyStack<Frame<T> > stk;stk.push(Frame<T>(beg0, tmp0, len0, 0));while (true) {++cnt;if (TOP(len) > 1) {mid = TOP(len) / 2;stk.push(Frame<T>(TOP(beg), TOP(tmp), mid, 1));continue;L1:mid = TOP(len) / 2;stk.push(Frame<T>(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid, 2));continue;L2:mid = TOP(len) / 2;mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));} else*TOP(tmp) = *TOP(beg);int retaddr0 = TOP(retaddr);stk.pop();switch (retaddr0) {case 0: return;case 1: goto L1;case 2: goto L2;}}}// This Implementation Use GCC's goto saved label value// Very similiar with recursive versiontemplate<class T>void MergeSort3(T* beg0, T* tmp0, int len0) {MyEntry:int mid;int retaddr;Frame<T>::base = beg0;MyStack<Frame<T> > stk;stk.push(Frame<T>(beg0, tmp0, len0, 0));#define Cat1(a,b) a##b#define Cat(a,b) Cat1(a,b)#define HereLabel() Cat(HereLable_, __LINE__)#define RecursiveCall(beg, tmp, len) /stk.push(Frame<T>(beg, tmp, len, (char*)&&HereLabel() - (char*)&&MyEntry)); /continue; /HereLabel():;//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~// retaddr == 0 是最外层的递归调用,// 只要到达这一层时 retaddr 才为 0,// 此时就可以返回了#define MyReturn /retaddr = TOP(retaddr); /stk.pop(); /if (0 == retaddr) { /return; /} /goto *((char*)&&MyEntry + retaddr);//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~while (true) {if (TOP(len) > 1) {mid = TOP(len) / 2;RecursiveCall(TOP(beg), TOP(tmp), mid);mid = TOP(len) / 2;RecursiveCall(TOP(beg)+mid, TOP(tmp)+mid, TOP(len)-mid);mid = TOP(len) / 2;mymerge(TOP(tmp), mid, TOP(tmp)+mid, TOP(len)-mid, TOP(beg));memcpy(TOP(tmp), TOP(beg), sizeof(T)*TOP(len));} else*TOP(tmp) = *TOP(beg);MyReturn;}}template<class T>void MergeSortDriver(T* beg, int len, void (*mf)(T* beg_, T* tmp_, int len_)){T* tmp = new T[len];(*mf)(beg, tmp, len);delete[] tmp;}#define test(a,n,mf) /memcpy(a, b, sizeof(a[0])*n); /MergeSortDriver(a, n, &mf); /printf("sort by %s:", #mf); /for (i = 0; i < n; ++i) printf("% ld", a[i]); /printf("/n");int main(int argc, char* argv[]){int n = argc - 1;int i;long* a = new long[n];long* b = new long[n];for (i = 0; i < n; ++i)b[i] = strtol(argv[i+1], NULL, 10);test(a, n, MergeSort1);test(a, n, MergeSort2);test(a, n, MergeSort3);printf("All Successed/n");delete[] a;delete[] b;return 0;}