// Copyright 2013 Google Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pretty

import (
	"encoding"
	"fmt"
	"reflect"
	"sort"
)

// isZeroVal reports whether val is deeply equal (per reflect.DeepEqual) to the
// zero value of its type. Values that cannot be turned back into an interface,
// such as unexported struct fields, are never reported as zero.
func isZeroVal(val reflect.Value) bool {
	if val.CanInterface() {
		zero := reflect.Zero(val.Type())
		return reflect.DeepEqual(val.Interface(), zero.Interface())
	}
	return false
}

// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// The zero value is ready to use; both maps are allocated lazily on first
// write.
type pointerTracker struct {
	addrs map[uintptr]int // addrs[address] = outstanding track calls for address

	lastID int
	ids    map[uintptr]int // ids[address] = id assigned by keep
}

// track tracks following a reference (pointer, slice, map, etc).  Every call to
// track should be paired with a call to untrack.
func (p *pointerTracker) track(ptr uintptr) {
	if p.addrs == nil {
		p.addrs = make(map[uintptr]int)
	}
	p.addrs[ptr]++
}

// untrack registers that we have backtracked over the reference to ptr,
// undoing one earlier call to track. An untrack with no matching track is a
// no-op, rather than panicking on a nil map or leaving a negative count
// behind (which seen would misreport as ptr still being on the path).
func (p *pointerTracker) untrack(ptr uintptr) {
	if p.addrs[ptr] <= 1 {
		// Last (or no) outstanding track for ptr: drop the entry entirely.
		// delete is a no-op on a nil map or a missing key.
		delete(p.addrs, ptr)
		return
	}
	p.addrs[ptr]--
}

// seen reports whether ptr has an outstanding track call, i.e. whether it is
// already being followed somewhere along the current path.
func (p *pointerTracker) seen(ptr uintptr) bool {
	_, ok := p.addrs[ptr]
	return ok
}

// keep returns the ID for ptr, allocating the next sequential ID (starting at
// 1) the first time ptr is kept.
func (p *pointerTracker) keep(ptr uintptr) int {
	if id, ok := p.ids[ptr]; ok {
		return id
	}
	if p.ids == nil {
		p.ids = make(map[uintptr]int)
	}
	p.lastID++
	p.ids[ptr] = p.lastID
	return p.lastID
}

// id returns the ID previously assigned to ptr by keep, and whether one
// exists. It never allocates: indexing a nil map is safe and yields the zero
// value.
func (p *pointerTracker) id(ptr uintptr) (int, bool) {
	id, ok := p.ids[ptr]
	return id, ok
}

// reflector adds local state to the recursive reflection logic.
//
// The embedded *Config supplies the options consulted by val2node (Formatter,
// PrintStringers, PrintTextMarshalers, IncludeUnexported, SkipZeroFields).
// The embedded *pointerTracker records the references currently being
// followed; when it is nil, follow performs no cycle tracking.
type reflector struct {
	*Config
	*pointerTracker
}

// follow renders val, which was reached through the reference at address ptr
// and may lead back to itself.
//
// With cycle tracking disabled (nil pointerTracker), val is rendered directly.
// Otherwise, if an ancestor is already rendering ptr, a ref marker carrying
// ptr's ID is returned instead of expanding val again. If ptr has been given
// an ID by the time val finishes rendering, the result is wrapped in a target
// marker with that ID.
func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
	tracker := r.pointerTracker
	switch {
	case tracker == nil:
		// Cycle tracking is off.
		return r.val2node(val)
	case tracker.seen(ptr):
		// Already on the current path: point back at it instead of recursing.
		return ref{tracker.keep(ptr)}
	}

	// Mark ptr as in-progress for the duration of this branch only.
	tracker.track(ptr)
	defer tracker.untrack(ptr)

	rendered := r.val2node(val)
	if id, found := tracker.id(ptr); found {
		return target{id, rendered}
	}
	return rendered
}

// val2node converts val into the node tree that the formatter renders.
//
// For values that can be converted to an interface, rendering precedence is:
//  1. If r.Formatter has an entry for val's type, a non-nil entry is called
//     with val and its string result is emitted verbatim; a nil entry skips
//     the Stringer/TextMarshaler checks and falls through to step 4.
//  2. Otherwise, when PrintStringers is set, a fmt.Stringer is rendered via
//     String().
//  3. Otherwise, when PrintTextMarshalers is set, an encoding.TextMarshaler is
//     rendered via MarshalText(); a marshalling error falls through to step 4.
//  4. Kind-based rendering (below).
//
// Pointers, slice elements and map values are reached through follow so that
// reference cycles can be detected; array elements, struct fields and
// interface contents are rendered directly.
func (r *reflector) val2node(val reflect.Value) node {
	if !val.IsValid() {
		return rawVal("nil")
	}

	if val.CanInterface() {
		v := val.Interface()
		if formatter, ok := r.Formatter[val.Type()]; ok {
			// A nil entry deliberately disables the Stringer/TextMarshaler
			// handling for this type and falls through to kind-based output.
			if formatter != nil {
				// NOTE(review): assumes the registered formatter is a func whose
				// first result is a string — presumably validated where
				// Formatter is populated; confirm.
				res := reflect.ValueOf(formatter).Call([]reflect.Value{val})
				return rawVal(res[0].Interface().(string))
			}
		} else {
			if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
				return stringVal(s.String())
			}
			if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
				// On error, fall through to kind-based rendering.
				if raw, err := t.MarshalText(); err == nil {
					return stringVal(string(raw))
				}
			}
		}
	}

	switch kind := val.Kind(); kind {
	case reflect.Ptr:
		if val.IsNil() {
			return rawVal("nil")
		}
		return r.follow(val.Pointer(), val.Elem())
	case reflect.Interface:
		if val.IsNil() {
			return rawVal("nil")
		}
		return r.val2node(val.Elem())
	case reflect.String:
		return stringVal(val.String())
	case reflect.Slice:
		n := list{}
		length := val.Len()
		ptr := val.Pointer()
		for i := 0; i < length; i++ {
			n = append(n, r.follow(ptr, val.Index(i)))
		}
		return n
	case reflect.Array:
		n := list{}
		length := val.Len()
		for i := 0; i < length; i++ {
			n = append(n, r.val2node(val.Index(i)))
		}
		return n
	case reflect.Map:
		// Render each key to a compact string so entries can be sorted into a
		// stable order. Keys go through the same reflector, so pointers inside
		// keys are still subject to cycle tracking.
		keys := val.MapKeys()
		pairs := make([]mapPair, 0, len(keys))
		for _, key := range keys {
			pairs = append(pairs, mapPair{
				key:   new(formatter).compactString(r.val2node(key)),
				value: val.MapIndex(key),
			})
		}
		sort.Sort(byKey(pairs))

		// Process the keys into the final representation
		ptr, n := val.Pointer(), keyvals{}
		for _, pair := range pairs {
			n = append(n, keyval{
				key: pair.key,
				val: r.follow(ptr, pair.value),
			})
		}
		return n
	case reflect.Struct:
		n := keyvals{}
		typ := val.Type()
		fields := typ.NumField()
		for i := 0; i < fields; i++ {
			sf := typ.Field(i)
			if !r.IncludeUnexported && sf.PkgPath != "" {
				continue
			}
			field := val.Field(i)
			if r.SkipZeroFields && isZeroVal(field) {
				continue
			}
			n = append(n, keyval{sf.Name, r.val2node(field)})
		}
		return n
	case reflect.Bool:
		if val.Bool() {
			return rawVal("true")
		}
		return rawVal("false")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rawVal(fmt.Sprintf("%d", val.Int()))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rawVal(fmt.Sprintf("%d", val.Uint()))
	case reflect.Uintptr:
		return rawVal(fmt.Sprintf("0x%X", val.Uint()))
	case reflect.Float32:
		// val.Float() widens to float64. Formatting the widened value exposes
		// binary rounding noise (float32(0.1) would print as
		// 0.10000000149011612), so narrow back to float32 first. The
		// conversion is exact because the value originated as a float32.
		return rawVal(fmt.Sprintf("%v", float32(val.Float())))
	case reflect.Float64:
		return rawVal(fmt.Sprintf("%v", val.Float()))
	case reflect.Complex64:
		// Same narrowing as Float32, applied to both components.
		return rawVal(fmt.Sprintf("%v", complex64(val.Complex())))
	case reflect.Complex128:
		return rawVal(fmt.Sprintf("%v", val.Complex()))
	}

	// Fall back to the default %#v if we can
	if val.CanInterface() {
		return rawVal(fmt.Sprintf("%#v", val.Interface()))
	}

	return rawVal(val.String())
}

// mapPair couples a map value with the compact string rendering of its key.
// The rendered key is what map entries are ordered by.
type mapPair struct {
	key   string
	value reflect.Value
}

// byKey implements sort.Interface, ordering mapPairs by rendered key using
// plain string comparison.
type byKey []mapPair

// Len returns the number of pairs.
func (pairs byKey) Len() int {
	return len(pairs)
}

// Swap exchanges the pairs at positions i and j.
func (pairs byKey) Swap(i, j int) {
	pairs[j], pairs[i] = pairs[i], pairs[j]
}

// Less reports whether the key at i sorts strictly before the key at j.
func (pairs byKey) Less(i, j int) bool {
	return pairs[j].key > pairs[i].key
}