# SUMMARY:      createstreamnetwork.py
# USAGE:        Main code for python version createstreamnetwork
# ORG:          Pacific Northwest National Laboratory
# AUTHOR:       Zhuoran Duan
# E-MAIL:       zhuoran.duan@pnnl.gov
# ORIG-DATE:    Apr-2017
# DESCRIPTION:  Python version of original createstreamnetwork aml
# DESCRIP-END.
# COMMENTS:     This Python script is a port of the original AML script
#		createstreamnetwork.aml, which is part of DHSVM.
#
# Last Change: 2017-08-10 

# -------------------------------------------------------------
#	Import system modules
# -------------------------------------------------------------

import arcpy
from arcpy import env
from arcpy.sa import *
import sys
import os
import numpy as np
import shutil

# -------------------------------------------------------------------#
# --------------------------- WorkSpace  ----------------------------#
# -------------------------------------------------------------------#
# Project folder: used as the arcpy workspace and as the base folder for the
# text/ascii outputs. Keep the trailing "/" because names are appended to it.
path = "C:/Users/jwon/Desktop/kitsap_clean/arcpy/kitsap8/"
env.workspace = path
##-------------------------------------------------------------------#
###########           Setup Input
#-------------------------------------------------------------------#
elev = "dem"        # Name of DEM grid
wshd = "mask"       # Name of mask file
soil = "soil"       # Name of soil file (not referenced in this script)
veg = "veg"         # Name of veg file (not referenced in this script)
soild = "soild"     # Name of soil depth file (created by this script)
stream = "stream"   # Name of stream arc feature class (created inside odb)

runMask = False     # Option to remake mask file
runSoil = False     # Option to remake soil bin (not referenced in this script)
runVeg = False      # Option to remake veg bin (not referenced in this script)

sourceArea = 100000  # Min source area to initiate stream (sq. meter)
soil_min = .76       # Minimum Soil Depth (meter)
soil_max = 5         # Maximum Soil Depth (meter)

# Output file geodatabase. It ends in "/" so feature-class names can be
# appended directly. ``path`` already ends in "/", so no separator is added
# here (the previous "/output.gdb/" produced ".../kitsap8//output.gdb/").
odb = path + "output.gdb/"

###-------------------------------------------------------------------#
###------------------------   End of Edits ---------------------------#
###---------------  Spatial Analysis License Required   --------------#
###-------------------------------------------------------------------#
print("Create DHSVM input start")

# Check for valid inputs: every later step is derived from the DEM, so stop
# immediately (with a non-zero exit status) when it is missing.
if not arcpy.Exists(elev):
    print('DEM file not found: {}'.format(elev))
    sys.exit(1)


# Use the DEM as the template for the coordinate system, extent and cell size
# of every raster produced below.
env.outputCoordinateSystem = elev
env.extent = elev
env.cellSize = elev
env.overwriteOutput = True
arcpy.CheckOutExtension("Spatial")


## TODO
## Fill dem

def file_clear(ifile, file_desc=None):
    """Delete the dataset ``ifile`` if it exists, so it can be recreated.

    ifile     : dataset name or path passed to arcpy.Exists / Delete_management.
    file_desc : optional label; when truthy, a notice is printed before the
                existing dataset is deleted.
    Returns None.
    """
    if not arcpy.Exists(ifile):
        return
    if file_desc:
        print("\t{} already exists. Deleting existing file and recreating".format(file_desc))
    arcpy.Delete_management(ifile)


# Constant raster of 1s over the full DEM extent. Further down, env.mask is
# switched to it while values are sampled at arc nodes.
file_clear("fullMask")
fullMask = CreateConstantRaster(1, "INTEGER", env.cellSize, env.extent)
fullMask.save("fullMask")

# Create Flow Direction
print("Creating Flow Direction")
file_clear("flow_dir")
flowdir = FlowDirection(elev, "", "")
flowdir.save("flow_dir")

# Get Mask: rebuild it when requested or when no mask dataset exists.
# NOTE(review): when the mask does not exist, ``wshd`` is still passed to
# Watershed as the pour-point input; confirm what pour points are intended.
if runMask or not arcpy.Exists(wshd):
    print("Recreating mask")
    wshd = Watershed(flowdir, wshd, "")
env.mask = wshd


# Create Flow Accumulation
print("Creating Flow Accumulation")
file_clear("flow_acc")
flowacc = FlowAccumulation(flowdir, "", "INTEGER")
flowacc.save("flow_acc")

# Convert the stream-initiation area (sq. meter) into a number of cells, using
# the DEM's X cell size, then keep the cells whose accumulation exceeds that
# count and multiply by the watershed mask.
deltax = arcpy.GetRasterProperties_management(elev, "CELLSIZEX").getOutput(0)
cell_x = float(deltax)
sourcepix = sourceArea / (cell_x * cell_x)
rivg = Con(flowacc > sourcepix, 1) * wshd

# Create the output file geodatabase once. Check the real gdb location (the
# old check looked at "/output.gdb", which never matched). CreateFileGDB takes
# the parent folder and the bare gdb name separately.
gdb_path = os.path.normpath(odb)
if not arcpy.Exists(gdb_path):
    print('Creating geodatabase')
    arcpy.CreateFileGDB_management(os.path.dirname(gdb_path), os.path.basename(gdb_path))

# Vectorize the stream raster into the polyline feature class ``streamnet``
# inside the output geodatabase. The StreamLink raster is also saved as
# "streamshape".
print('Creating stream shapefile')
streamnet = "{}{}".format(odb, stream)
file_clear(streamnet, "Stream file")

streamlink = StreamLink(rivg, flowdir)
StreamToFeature(streamlink, flowdir, streamnet, "NO_SIMPLIFY")
streamlink.save("streamshape")

# Find contributing area to each cell along stream network.
# Rasterize the stream arcs by "arcid" and run Watershed with those cells as
# pour points, giving each arc's local drainage area.
file_clear(rivg)
arcpy.PolylineToRaster_conversion(streamnet, "arcid", rivg, "MAXIMUM_LENGTH", "NONE", env.cellSize)
local = Watershed(flowdir, rivg, "VALUE")

file_clear("local")
local.save("local")

# Copy the watershed cell count ("Count", joined on arcid == Value) into the
# arc field "local". Arcs with no joined count get 0.
arcpy.AddField_management(streamnet, "local", "LONG")
arcpy.JoinField_management(streamnet, "arcid", local, "Value", "Count")

# arcpy.da.UpdateCursor takes a field list and yields indexable rows (the
# legacy arcpy.UpdateCursor does neither). No trailing `del row`: that raised
# NameError when the feature class was empty.
with arcpy.da.UpdateCursor(streamnet, ['COUNT', 'local']) as cursor:
    for row in cursor:
        row[1] = 0 if row[0] is None else row[0]
        cursor.updateRow(row)

arcpy.DeleteField_management(streamnet, "COUNT")

print('local file created')
arcpy.Delete_management(rivg)

###-------------------------------------------------------------------#
### create node point coverage and find elevations
###-------------------------------------------------------------------#

# Write the START and END vertex of every stream arc as point features.
node_st = odb + "nodestart"
node_ed = odb + "nodeend"

for node_fc, vertex_type, node_desc in ((node_st, "START", "Start node file"),
                                        (node_ed, "END", None)):
    file_clear(node_fc, node_desc)
    arcpy.FeatureVerticesToPoints_management(streamnet, node_fc, vertex_type)

# Sample the flow accumulation at the START node of each arc (field MAXGRID).
# env.mask is switched to the full-extent mask first. Presumably this lets
# nodes outside the watershed mask still get values; TODO confirm.
# Null accumulation cells are filled with the grid maximum.
env.mask = fullMask
tmpacc = Con(IsNull("flow_acc") == 1, int(flowacc.maximum), "flow_acc")

ExtractMultiValuesToPoints(node_st, [[tmpacc, "MAXGRID"]])

arcpy.AddField_management(streamnet, "downarc", "LONG")

# Elevation at the start (SELEV) and end (EELEV) node of each arc. Null DEM
# cells under the end nodes are filled with the DEM minimum.
ExtractMultiValuesToPoints(node_st, [[elev, "SELEV"]], "NONE")
elevras = Raster(elev)
tmpelev = Con(IsNull(elevras) == 1, int(elevras.minimum), elev)
ExtractMultiValuesToPoints(node_ed, [[tmpelev, "EELEV"]], "NONE")

# Restore the watershed mask and copy the sampled node values onto the arcs.
env.mask = wshd
arcpy.JoinField_management(streamnet, "arcid", node_st, "arcid", "SELEV")
arcpy.JoinField_management(streamnet, "arcid", node_ed, "arcid", "EELEV")
arcpy.JoinField_management(streamnet, "arcid", node_st, "arcid", "MAXGRID")

# Attribute fields filled in here or by later steps. Each tuple is passed
# positionally to AddField_management: (name, type[, precision, scale]).
new_field_specs = (
    ("uparc", "LONG"),
    ("dz", "FLOAT", 12, 3),
    ("slope", "FLOAT", 12, 5),
    ("meanmsq", "FLOAT"),
    ("segorder", "LONG"),
    ("chanclass", "SHORT", 8, ""),
    ("hyddepth", "FLOAT", 8, 2),
    ("hydwidth", "FLOAT", 8, 2),
    ("effwidth", "FLOAT", 8, 2),
    ("effdepth", "FLOAT", 8, 2),
)
for field_spec in new_field_specs:
    arcpy.AddField_management(streamnet, *field_spec)

print('Calculating Slope of channel segment')
# dz: absolute elevation difference between an arc's start and end node.
arcpy.CalculateField_management(streamnet, "dz", "abs(!SELEV! - !EELEV!)", "PYTHON_9.3", "")

# slope = dz / length, floored at 0.00001 so no arc is assigned a flatter slope.
expression = "clacSlope(float(!dz!),float(!Shape_Length!))"
codeblock = """def clacSlope(dz,length):
    if (dz/length)>0.00001:
        return dz/length
    else:
        return 0.00001"""
arcpy.CalculateField_management(streamnet, "slope", expression, "PYTHON_9.3", codeblock)

# Initial values for the topology fields computed in the segment-order pass.
for init_field, init_value in (("segorder", "0"), ("uparc", "0"),
                               ("downarc", "-1"), ("meanmsq", "0.0")):
    arcpy.CalculateField_management(streamnet, init_field, init_value, "PYTHON_9.3")


#-------------------------------------------------------------------#
# Calculate Segment Order
#-------------------------------------------------------------------#
arr = arcpy.da.TableToNumPyArray(streamnet, ('from_node', 'to_node', 'segorder', 'local', 'MAXGRID',
                                             'meanmsq', 'uparc', 'downarc', 'arcid'))

# Area of one grid cell, hoisted out of the per-arc loop.
cellarea = float(env.cellSize) * float(env.cellSize)

# Index the arcs once by their end nodes and ids. The previous nested loops
# rescanned the whole table for every arc (O(n^2)).
arcs_from = {}      # from_node -> row index of the LAST arc starting there
arcs_into = {}      # to_node   -> row indices of all arcs ending there, in table order
for idx, (fnode, tnode) in enumerate(zip(arr['from_node'], arr['to_node'])):
    arcs_from[fnode] = idx
    arcs_into.setdefault(tnode, []).append(idx)
row_of_arcid = {arcid: idx for idx, arcid in enumerate(arr['arcid'])}

print('Looking for downstream arc')
# downarc = arcid of the arc whose from_node is this arc's to_node. If several
# arcs match, the last one in table order wins (same as before). If none
# matches, the preset -1 is kept.
for jj, tnode in enumerate(arr['to_node']):
    idx = arcs_from.get(tnode)
    if idx is not None:
        arr['downarc'][jj] = arr['arcid'][idx]

print('Looking for upstream arc')
# Calculate upstream arc based on max contributing area. Arcs with no inflow
# are headwaters (uparc -1, segorder 1). "local" is an integer field, so divide
# by 2.0: plain "/ 2" truncated under Python 2.
for jj, fnode in enumerate(arr['from_node']):
    upstream = arcs_into.get(fnode)
    if not upstream:
        arr['uparc'][jj] = -1
        arr['segorder'][jj] = 1
        arr['meanmsq'][jj] = arr['local'][jj] / 2.0 * cellarea
    else:
        upstream = np.asarray(upstream, dtype=int)
        loc = np.argmax(arr['MAXGRID'][upstream] + arr['local'][upstream])
        arr['uparc'][jj] = arr['arcid'][upstream[loc]]
        arr['meanmsq'][jj] = (arr['MAXGRID'][jj] + arr['local'][jj] / 2.0) * cellarea

# Calculate segorder: each pass pushes order+1 onto the downstream arc of every
# arc currently at `order`, keeping the larger value. The result is the length
# (in arcs) of the longest upstream chain. In an acyclic network that can never
# exceed the number of arcs, so a larger order signals a cycle instead of
# looping forever.
print('Calculating segment order')
order = 1
while True:
    current = np.flatnonzero(arr['segorder'] == order)
    if current.size == 0:
        break
    if order > len(arr):
        raise RuntimeError('Cycle detected in stream network while calculating segment order')
    for jj in current:
        idx = row_of_arcid.get(arr['downarc'][jj])
        if idx is not None:
            arr['segorder'][idx] = max(order + 1, arr['segorder'][idx])
    order += 1


# Write the computed fields back to the stream feature class (matched on arcid).
arcpy.da.ExtendTable(streamnet, "arcid", arr, "arcid", append_only=False)

#-------------------------------------------------------------------#
###########          Channel Hydraulic Classes
#-------------------------------------------------------------------#
import channelclass
print('Assign channel class')
channelclass.channelclassfun(streamnet)

#-------------------------------------------------------------------#
###########           SOIL DEPTH
#-------------------------------------------------------------------#
import soildepthscript

print("Creating soil depth map")
file_clear(soild, "Soil depth file")

# Passes the flow accumulation grid, the DEM, the min/max soil depth (meter)
# and the name of the soil-depth grid to create.
soildepthscript.soildepthfun("flow_acc", elev, soil_min, soil_max, soild)

#-------------------------------------------------------------------#
#######   Shallowest Soil Depth for channel segment
#-------------------------------------------------------------------#
# Rasterize the arcs by arcid, then take the minimum soil depth under each arc
# and join it onto the arcs as field "MIN".
stmlineras = "stmlineras"
soild_table = "soild_zonal"

file_clear(stmlineras)
arcpy.PolylineToRaster_conversion(streamnet, "arcid", stmlineras, "", "", env.cellSize)

file_clear(soild_table, "Zonal statistics table")
ZonalStatisticsAsTable(stmlineras, "VALUE", soild, soild_table, "DATA", "MINIMUM")
arcpy.JoinField_management(streamnet, "arcid", soild_table, "VALUE", "MIN")

########################################
# When a channel is shorter than the grid cell size, it may not show up in the
# polyline-to-raster conversion, so the zonal statistics give no depth (MIN is
# null) for that arc. As a fallback, sample the soil depth at every vertex of
# each arc and take the per-arc minimum (MIN_soildepth).
all_nodes = odb + "all_nodes"
file_clear(all_nodes)
arcpy.FeatureVerticesToPoints_management(streamnet, all_nodes, "ALL")
ExtractMultiValuesToPoints(all_nodes, [[soild, "soildepth"]], "BILINEAR")

arcpy.Statistics_analysis(all_nodes, "all_nodes_statistics", "soildepth MIN", "arcid")
arcpy.JoinField_management(streamnet, "arcid", "all_nodes_statistics", "arcid", "MIN_soildepth")
arcpy.AddField_management(streamnet, "segdepth", "FLOAT", 8, 2)

# segdepth = smaller of the vertex minimum and the zonal minimum, ignoring any
# null value. It stays null only when both are null. Comparing None with a
# number previously gave None (Py2) or raised TypeError (Py3).
with arcpy.da.UpdateCursor(streamnet, ['MIN_soildepth', 'MIN', 'segdepth']) as cursor:
    for row in cursor:
        depths = [d for d in row[:2] if d is not None]
        row[2] = min(depths) if depths else None
        cursor.updateRow(row)

# Remove helper fields and intermediate datasets. The joined field is
# "MIN_soildepth" (it was misspelled "MIN_soild", so it was never removed).
arcpy.DeleteField_management(streamnet, "MIN")
arcpy.DeleteField_management(streamnet, "MIN_soildepth")
arcpy.Delete_management(all_nodes)
arcpy.Delete_management("all_nodes_statistics")

arcpy.CalculateField_management(streamnet, "effdepth", "0.95*float(!segdepth!)", "PYTHON_9.3", "")

#-------------------------------------------------------------------#
#######     Create stream network file for DHSVM
#-------------------------------------------------------------------#
# Check for and remove the same file that is written below (in the project
# folder), not a file of that name in the current working directory.
network_file = os.path.join(path, 'stream.network.dat')
if os.path.exists(network_file):
    print('stream.network.dat already exists, delete and create new')
    os.remove(network_file)
    print('stream.network.dat successfully deleted')

# One line per arc: arcid, segorder, slope, Shape_Length, chanclass, downarc.
print('creating new stream.network.dat file')
arr_export = arcpy.da.TableToNumPyArray(streamnet, ('arcid', 'segorder', 'slope', 'Shape_Length',
                                                    'chanclass', 'downarc'))
np.savetxt(network_file, arr_export, fmt="%5d %3d %11.5f %17.5f %3d %7d")

#-------------------------------------------------------------------#
#######     Create stream map file for DHSVM
#-------------------------------------------------------------------#
print('running wshdslope')
import wshdslope
# Names of the slope and aspect grids (deleted again during clean-up).
slope, aspect = "gridslope_4d", "gridaspect"

wshdslope.wshdslopefun(elev, slope, aspect, flowacc)
import rowcolmap2
print('running rowcolmap')

# Row/column polygon shapefile produced over the full-extent mask.
rcpoly = "rowcolpoly.shp"
rowcolmap2.rowcolmapfun(elev, rcpoly, fullMask)

print('creating streammap')
outcover = odb + "outcover"
file_clear(outcover)
# Intersect the arcs with the row/column polygons (presumably one polygon per
# grid cell; confirm in rowcolmap2). Use ``rcpoly`` rather than repeating the
# literal, so the two names cannot drift apart.
arcpy.Intersect_analysis([streamnet, rcpoly], outcover, "", "", "")

from roadaspect import roadaspectfun
print('running roadaspect')
roadaspectfun(outcover)

from streammapfile import streammapfilefun
print('generating stream map file')
streammapfilefun(outcover, path)

#-------------------------------------------------------------------#
#######     Create ascii files
#-------------------------------------------------------------------#
# Recreate the ascii/ output folder from scratch; any previous contents are
# removed. os.path.join avoids the doubled "//" that the string concatenation
# produced.
apath = os.path.join(path, "ascii")
if os.path.exists(apath):
    shutil.rmtree(apath)

os.mkdir(apath)
arcpy.RasterToASCII_conversion(soild, os.path.join(apath, "soild.asc"))




#-------------------------------------------------------------------#
#######     Clean Up Tmp files
#-------------------------------------------------------------------#

print('File clean-up')
# Not deleted here: flow_dir, flow_acc, outcover and the row/col polygons.

# Named intermediates go through file_clear, which skips names that do not
# exist. Delete_management fails on missing input, and "streaml_time1" /
# "ifthe_ras" are not created anywhere in this script (presumably by a helper
# module), so a missing one used to abort the clean-up.
for tmp_name in (slope, aspect, soild_table, "streaml_time1", "ifthe_ras",
                 node_st, node_ed, stmlineras):
    file_clear(tmp_name)

# Raster objects created earlier in the script.
for tmp_raster in (tmpelev, fullMask, local, tmpacc):
    arcpy.Delete_management(tmp_raster)

print('Complete')







