
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def cylinder(r, n=20):
    '''
    Return surface coordinates of a solid of revolution about the z axis.

    Port of MATLAB's ``cylinder``. Each entry of ``r`` is the radius of one
    ring. The rings are spaced evenly along z from 0 to 1, which is why the
    result is a "unit" cylinder in height.

    INPUTS:  r - radii of the rings, given as a scalar, list, or row or
                 column vector. Any array-like input is flattened. A single
                 radius is duplicated, so the result is a straight cylinder
                 with rings at z = 0 and z = 1 (as MATLAB does) rather than
                 one flat ring.
             n - number of equally spaced points around each ring
                 (default 20). Each ring is returned with n + 1 points,
                 because the first point is repeated at the end to close
                 the curve.

    OUTPUTS: x,y,z - arrays of shape (number of rings, n + 1), usable
                     directly with plot_surface / plot_wireframe.

    RAISES:  ValueError if r contains no radii.
    '''
    # Flatten any vector-like input into a column. It then broadcasts
    # against the row of angles below, giving one output row per radius.
    r = np.asarray(r).reshape(-1, 1)
    if r.shape[0] == 0:
        raise ValueError("r must contain at least one radius")
    if r.shape[0] == 1:
        # A lone radius would give linspace(0, 1, 1) == [0], i.e. a single
        # flat ring; duplicate it to span z = 0 .. 1 instead.
        r = np.vstack((r, r))

    # Angles around the circumference. There are n + 1 samples so that the
    # last point coincides with the first and the ring is closed.
    points = np.linspace(0, 2 * np.pi, n + 1)
    x = np.cos(points) * r
    y = np.sin(points) * r

    # One evenly spaced z level per ring, repeated along each row.
    rpoints = np.linspace(0, 1, r.shape[0]).reshape(-1, 1)
    z = np.ones((1, n + 1)) * rpoints

    return x, y, z


# Demo: revolve the profile 4*cos(t) about the z axis and show it four
# times, each with a different transformation of the height coordinate.
t = np.arange(0, 2 * np.pi, np.pi / 10)
X, Y, Z = cylinder(4 * np.cos(t), 50)

# Height transforms, in subplot order: square root, square, x5, identity.
height_variants = (np.sqrt(Z), np.square(Z), Z * 5, Z)

# set up a square figure holding a 2x2 grid of 3-D wireframes
fig = plt.figure(figsize=plt.figaspect(1.0))
for slot, heights in enumerate(height_variants, start=1):
    ax = fig.add_subplot(2, 2, slot, projection='3d')
    ax.plot_wireframe(X, Y, heights)
plt.show()
