package pattern

import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"reflect"
)

// astTypes maps a type name to the go/ast struct type carrying that name.
// NodeToAST looks up the name of a pattern node's type here to find which
// go/ast struct to allocate and populate. Keys are derived from the types
// themselves, so an entry's name and its type can never disagree.
var astTypes = func() map[string]reflect.Type {
	prototypes := []interface{}{
		ast.ArrayType{},
		ast.AssignStmt{},
		ast.BasicLit{},
		ast.BinaryExpr{},
		ast.BranchStmt{},
		ast.CallExpr{},
		ast.CaseClause{},
		ast.ChanType{},
		ast.CommClause{},
		ast.CompositeLit{},
		ast.DeferStmt{},
		ast.Ellipsis{},
		ast.EmptyStmt{},
		ast.Field{},
		ast.ForStmt{},
		ast.FuncDecl{},
		ast.FuncLit{},
		ast.FuncType{},
		ast.GenDecl{},
		ast.GoStmt{},
		ast.Ident{},
		ast.IfStmt{},
		ast.ImportSpec{},
		ast.IncDecStmt{},
		ast.IndexExpr{},
		ast.InterfaceType{},
		ast.KeyValueExpr{},
		ast.MapType{},
		ast.RangeStmt{},
		ast.ReturnStmt{},
		ast.SelectStmt{},
		ast.SelectorExpr{},
		ast.SendStmt{},
		ast.SliceExpr{},
		ast.StarExpr{},
		ast.StructType{},
		ast.SwitchStmt{},
		ast.TypeAssertExpr{},
		ast.TypeSpec{},
		ast.TypeSwitchStmt{},
		ast.UnaryExpr{},
		ast.ValueSpec{},
	}
	byName := make(map[string]reflect.Type, len(prototypes))
	for _, proto := range prototypes {
		typ := reflect.TypeOf(proto)
		byName[typ.Name()] = typ
	}
	return byName
}()

// ASTToNode converts a go/ast value into its pattern representation.
//
// Scalars map directly: nil becomes Nil, a string becomes String and a
// token.Token becomes Token. Some wrappers are elided: *ast.ExprStmt and
// *ast.ParenExpr convert as their inner expression, and *ast.BlockStmt and
// *ast.FieldList convert as their List field (or Nil when the pointer is
// nil). Every other ast.Node is converted field by field, using the
// structNodes entry for its type name to decide which fields to visit, and
// the results are assembled with populateNode; a nil pointer yields Nil.
// A slice becomes a chain of List cells ending in an empty List, except
// that an empty slice is the empty List and a one-element slice converts
// as its sole element.
//
// ASTToNode panics on *ast.File and on any value it does not know how to
// convert.
func ASTToNode(node interface{}) Node {
	switch n := node.(type) {
	case nil:
		return Nil{}
	case string:
		return String(n)
	case token.Token:
		return Token(n)
	case *ast.File:
		panic("cannot convert *ast.File to Node")
	case *ast.ExprStmt:
		return ASTToNode(n.X)
	case *ast.ParenExpr:
		return ASTToNode(n.X)
	case *ast.BlockStmt:
		if n == nil {
			return Nil{}
		}
		return ASTToNode(n.List)
	case *ast.FieldList:
		if n == nil {
			return Nil{}
		}
		return ASTToNode(n.List)
	case *ast.BasicLit:
		// A nil literal short-circuits; a non-nil one continues to the
		// generic struct conversion below.
		if n == nil {
			return Nil{}
		}
	}

	if astNode, isAST := node.(ast.Node); isAST {
		ptr := reflect.ValueOf(astNode)
		typeName := ptr.Type().Elem().Name()
		shape, known := structNodes[typeName]
		if !known {
			panic(fmt.Sprintf("internal error: unhandled type %T", astNode))
		}
		if ptr.IsNil() {
			return Nil{}
		}
		fields := ptr.Elem()
		children := make([]Node, shape.NumField())
		for i := range children {
			child := fields.FieldByName(shape.Field(i).Name)
			children[i] = ASTToNode(child.Interface())
		}
		converted, err := populateNode(typeName, children, false)
		if err != nil {
			panic(fmt.Sprintf("internal error: %s", err))
		}
		return converted
	}

	seq := reflect.ValueOf(node)
	if seq.Kind() != reflect.Slice {
		panic(fmt.Sprintf("internal error: unhandled type %T", node))
	}
	switch seq.Len() {
	case 0:
		return List{}
	case 1:
		return ASTToNode(seq.Index(0).Interface())
	}
	// Build the cons list back to front so each cell can point at the
	// already-built remainder.
	var list List
	for i := seq.Len() - 1; i >= 0; i-- {
		list = List{Head: ASTToNode(seq.Index(i).Interface()), Tail: list}
	}
	return list
}

// NodeToAST converts a pattern Node back into go/ast form, substituting the
// values recorded in state for Bindings.
//
// The dynamic type of the result depends on node: a Binding yields its bound
// value (or an *ast.Ident naming it, if the value is a types.Object), a List
// yields []ast.Node, a Token yields token.Token, a String yields string, Nil
// yields nil, and any other node yields a pointer to the go/ast struct of the
// same name (looked up in astTypes), populated field by field.
//
// Non-integer fields are filled through coerceAST, which reverses the
// simplifications ASTToNode makes: a lone node is accepted where a list is
// expected, lists are rebuilt as *ast.BlockStmt or *ast.FieldList where the
// field requires it, and bare expressions are wrapped in *ast.ExprStmt where
// a statement is required.
//
// NodeToAST panics on unbound Bindings, on pattern nodes with no AST
// counterpart (Builtin, Any, Object, Function, Not, Or), and on values that
// cannot be stored in the field they are destined for.
func NodeToAST(node Node, state State) interface{} {
	switch node := node.(type) {
	case Binding:
		v, ok := state[node.Name]
		if !ok {
			// really we want to return an error here
			panic("XXX")
		}
		switch v := v.(type) {
		case types.Object:
			return &ast.Ident{Name: v.Name()}
		default:
			return v
		}
	case Builtin, Any, Object, Function, Not, Or:
		panic("XXX")
	case List:
		if (node == List{}) {
			return []ast.Node{}
		}
		head := NodeToAST(node.Head, state).(ast.Node)
		// The tail may be a Binding whose value is a typed slice such as
		// []ast.Expr or []ast.Stmt, so convert it rather than asserting
		// []ast.Node directly.
		tail := coerceAST(reflect.TypeOf([]ast.Node(nil)), reflect.ValueOf(NodeToAST(node.Tail, state)))
		return append([]ast.Node{head}, tail.Interface().([]ast.Node)...)
	case Token:
		return token.Token(node)
	case String:
		return string(node)
	case Nil:
		return nil
	}

	name := reflect.TypeOf(node).Name()
	T, ok := astTypes[name]
	if !ok {
		panic(fmt.Sprintf("internal error: unhandled type %T", node))
	}
	v := reflect.ValueOf(node)
	out := reflect.New(T)
	for i := 0; i < T.NumField(); i++ {
		fieldName := T.Field(i).Name
		fNode := v.FieldByName(fieldName)
		if !fNode.IsValid() {
			// The pattern node has no field by this name; leave the AST
			// field at its zero value.
			continue
		}
		fAST := out.Elem().FieldByName(fieldName)
		r := NodeToAST(fNode.Interface().(Node), state)
		switch fAST.Kind() {
		case reflect.Int:
			// Integer fields (such as token.Token operators) accept either
			// an integer value or the token's spelling.
			c := reflect.ValueOf(r)
			switch c.Kind() {
			case reflect.String:
				tok, ok := tokensByString[c.String()]
				if !ok {
					// really we want to return an error here
					panic("XXX")
				}
				fAST.SetInt(int64(tok))
			case reflect.Int:
				// SetInt rather than Set: the value's named int type need
				// not be identical to the field's.
				fAST.SetInt(c.Int())
			default:
				panic(fmt.Sprintf("internal error: unexpected kind %s", c.Kind()))
			}
		default:
			fAST.Set(coerceAST(fAST.Type(), reflect.ValueOf(r)))
		}
	}

	return out.Interface().(ast.Node)
}

// coerceAST converts v, a value produced by NodeToAST, into a value that can
// be stored in a location of the go/ast type t. An invalid v (from a nil
// result) or a nil interface element yields the zero value of t. A v whose
// type is already assignable to t is returned unchanged. It panics if no
// conversion applies.
func coerceAST(t reflect.Type, v reflect.Value) reflect.Value {
	// Elements of interface-typed slices (e.g. []ast.Node) are unwrapped so
	// that their concrete type is what gets checked against t.
	if v.Kind() == reflect.Interface {
		v = v.Elem()
	}
	if !v.IsValid() {
		return reflect.Zero(t)
	}
	if v.Type().AssignableTo(t) {
		return v
	}

	switch {
	case t.Kind() == reflect.Slice:
		if v.Kind() != reflect.Slice {
			// A single node in the pattern stands for a one-element list.
			v = reflect.ValueOf([]interface{}{v.Interface()})
		}
		if v.Len() == 0 {
			return reflect.Zero(t)
		}
		converted := reflect.MakeSlice(t, v.Len(), v.Len())
		for i := 0; i < v.Len(); i++ {
			converted.Index(i).Set(coerceAST(t.Elem(), v.Index(i)))
		}
		return converted
	case t == reflect.TypeOf((*ast.BlockStmt)(nil)):
		// ASTToNode flattens a block into its statement list.
		list := coerceAST(reflect.TypeOf([]ast.Stmt(nil)), v)
		return reflect.ValueOf(&ast.BlockStmt{List: list.Interface().([]ast.Stmt)})
	case t == reflect.TypeOf((*ast.FieldList)(nil)):
		// ASTToNode flattens a field list into its fields.
		list := coerceAST(reflect.TypeOf([]*ast.Field(nil)), v)
		return reflect.ValueOf(&ast.FieldList{List: list.Interface().([]*ast.Field)})
	case t == reflect.TypeOf((*ast.Stmt)(nil)).Elem():
		// ASTToNode strips *ast.ExprStmt, leaving a bare expression.
		if e, ok := v.Interface().(ast.Expr); ok {
			return reflect.ValueOf(&ast.ExprStmt{X: e})
		}
		// A list in statement position (e.g. an else branch) was a block.
		if v.Kind() == reflect.Slice {
			return coerceAST(reflect.TypeOf((*ast.BlockStmt)(nil)), v)
		}
	}
	panic(fmt.Sprintf("internal error: cannot use %s as %s", v.Type(), t))
}